tidb/pkg/executor/join/hash_table_v2.go

// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package join

import (
	"sync/atomic"
	"unsafe"
)

type subTable struct {
	rowData *rowTable
	// hashTable stores row addresses as taggedPtrs. During the hash join build
	// stage the chunk data is converted into row format, and for each row there
	// is an unsafe.Pointer pointing to the start address of the row. The
	// unsafe.Pointer is converted to a taggedPtr and saved in hashTable.
	// Generally speaking, it is unsafe or even illegal in Go to save an
	// unsafe.Pointer into a uintptr and later convert the uintptr back to an
	// unsafe.Pointer: once the pointer value is stored in a uintptr it has no
	// pointer semantics and may become invalid after GC. But it is ok to do
	// this in hash join so far because
	// 1. the heapObjectsCanMove check makes sure that if an object is in the heap, its address will not change after GC
	// 2. a row address always points into `rowTableSegment.rawData`. `rawData` is a slice in `rowTableSegment` that is shared by multiple goroutines and grown at runtime, and this kind of slice is always allocated on the heap
	hashTable        []taggedPtr
	posMask          uint64
	isRowTableEmpty  bool
	isHashTableEmpty bool
}
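
// A minimal sketch of the tagged-pointer idea, assuming a hypothetical layout
// in which the tag occupies unused high bits of a 64-bit address (the real
// layout is decided by tagPtrHelper, not by this illustration):
//
//	const illustrativeTagMask = uint64(0xFFFF) << 48
//	tag := hashValue & illustrativeTagMask // tag bits derived from the hash value
//	tagged := uintptr(addr) | uintptr(tag) // pack tag bits into the pointer
//	addrBack := unsafe.Pointer(tagged &^ uintptr(illustrativeTagMask))
//
// During lookup the tag acts as a cheap bloom-filter-like check: if any tag
// bit of the probe is missing from the stored taggedPtr, no key in that bucket
// can match and the row data is never touched.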

func (st *subTable) getTotalMemoryUsage() int64 {
	return st.rowData.getTotalMemoryUsage() + getHashTableMemoryUsage(uint64(len(st.hashTable)))
}

// lookup returns the head of the bucket chain for hashValue, or 0 when the
// tag bits already rule out any match in that bucket.
func (st *subTable) lookup(hashValue uint64, tagHelper *tagPtrHelper) taggedPtr {
	ret := st.hashTable[hashValue&st.posMask]
	hashTagValue := tagHelper.getTaggedValue(hashValue)
	if uint64(ret)&hashTagValue != hashTagValue {
		// if the tag bits do not all match, the key cannot match
		return 0
	}
	return ret
}

// nextPowerOfTwo returns the smallest power of two that is strictly greater
// than value, and never less than 2. It panics if the result would overflow
// a uint64.
func nextPowerOfTwo(value uint64) uint64 {
	ret := uint64(2)
	round := 1
	for ; ret <= value && round <= 64; ret = ret << 1 {
		round++
	}
	if round > 64 {
		panic("input value is too large")
	}
	return ret
}
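
// For example, note the "strictly greater" behavior on exact powers of two:
//
//	nextPowerOfTwo(1)  // 2
//	nextPowerOfTwo(31) // 32
//	nextPowerOfTwo(32) // 64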

func newSubTable(table *rowTable) *subTable {
	ret := &subTable{
		rowData:          table,
		isHashTableEmpty: false,
		isRowTableEmpty:  false,
	}
	if table.rowCount() == 0 {
		ret.isRowTableEmpty = true
	}
	if table.validKeyCount() == 0 {
		ret.isHashTableEmpty = true
	}
	hashTableLength := max(nextPowerOfTwo(table.validKeyCount()), uint64(32))
	ret.hashTable = make([]taggedPtr, hashTableLength)
	ret.posMask = hashTableLength - 1
	return ret
}
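
// Because the table length is always a power of two, posMask (length-1) turns
// the modulo in "hashValue % length" into a single AND. A quick sketch,
// assuming a length of 32:
//
//	posMask := uint64(32 - 1)  // 0b11111
//	pos := hashValue & posMask // same result as hashValue % 32, but cheaper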

// updateHashValue inserts rowAddress at the head of the bucket chain for
// hashValue. It must only be used when a single goroutine owns this sub-table.
func (st *subTable) updateHashValue(hashValue uint64, rowAddress unsafe.Pointer, tagHelper *tagPtrHelper) {
	pos := hashValue & st.posMask
	prev := st.hashTable[pos]
	// merge this row's tag bits with those already accumulated at the bucket
	// head, so the stored tag covers every key in the chain
	tagValue := tagHelper.getTaggedValue(hashValue | uint64(prev))
	taggedAddress := tagHelper.toTaggedPtr(tagValue, rowAddress)
	st.hashTable[pos] = taggedAddress
	setNextRowAddress(rowAddress, prev)
}

// atomicUpdateHashValue is the concurrent variant of updateHashValue: the
// bucket head is replaced via a CAS loop so that several goroutines can
// insert into the same sub-table at once.
func (st *subTable) atomicUpdateHashValue(hashValue uint64, rowAddress unsafe.Pointer, tagHelper *tagPtrHelper) {
	pos := hashValue & st.posMask
	for {
		prev := taggedPtr(atomic.LoadUintptr((*uintptr)(unsafe.Pointer(&st.hashTable[pos]))))
		tagValue := tagHelper.getTaggedValue(hashValue | uint64(prev))
		taggedAddress := tagHelper.toTaggedPtr(tagValue, rowAddress)
		if atomic.CompareAndSwapUintptr((*uintptr)(unsafe.Pointer(&st.hashTable[pos])), uintptr(prev), uintptr(taggedAddress)) {
			setNextRowAddress(rowAddress, prev)
			break
		}
	}
}
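
// atomicUpdateHashValue above is the classic lock-free "push to list head"
// pattern. A self-contained sketch of the same pattern on a plain uintptr
// slot, with hypothetical names, looks like this:
//
//	var head uintptr // one hash bucket
//	for {
//		old := atomic.LoadUintptr(&head)
//		// ... prepare the new node so that it links to the old head ...
//		if atomic.CompareAndSwapUintptr(&head, old, newNode) {
//			break // nobody raced us; the new node is now the head
//		}
//		// CAS failed: another goroutine inserted first; retry with the fresh head
//	}
//
// Note that the method above publishes the new head first and fills in its
// next pointer only after the CAS succeeds, which is only safe as long as
// nothing traverses the chain while the build phase is still running.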

func (st *subTable) build(startSegmentIndex int, endSegmentIndex int, tagHelper *tagPtrHelper) {
	if startSegmentIndex == 0 && endSegmentIndex == len(st.rowData.segments) {
		for i := startSegmentIndex; i < endSegmentIndex; i++ {
			for _, index := range st.rowData.segments[i].validJoinKeyPos {
				rowAddress := st.rowData.segments[i].getRowPointer(index)
				hashValue := st.rowData.segments[i].hashValues[index]
				st.updateHashValue(hashValue, rowAddress, tagHelper)
			}
		}
	} else {
		for i := startSegmentIndex; i < endSegmentIndex; i++ {
			for _, index := range st.rowData.segments[i].validJoinKeyPos {
				rowAddress := st.rowData.segments[i].getRowPointer(index)
				hashValue := st.rowData.segments[i].hashValues[index]
				st.atomicUpdateHashValue(hashValue, rowAddress, tagHelper)
			}
		}
	}
}
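
// The branch above reads: a worker that owns every segment of the sub-table
// (startSegmentIndex == 0 and endSegmentIndex == len(segments)) is the only
// writer, so the cheaper non-atomic updateHashValue suffices; when several
// workers build disjoint segment ranges of the same sub-table concurrently,
// they can still collide on shared buckets and must go through
// atomicUpdateHashValue.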

type hashTableV2 struct {
	tables          []*subTable
	partitionNumber uint64
}

func (ht *hashTableV2) getPartitionMemoryUsage(partID int) int64 {
	if ht.tables[partID] != nil {
		return ht.tables[partID].getTotalMemoryUsage()
	}
	return 0
}

func (ht *hashTableV2) clearPartitionSegments(partID int) {
	if ht.tables[partID] != nil {
		ht.tables[partID].rowData.clearSegments()
		ht.tables[partID].hashTable = nil
	}
}

// rowPos is the position of a row inside a hashTableV2: which sub-table,
// which segment of that sub-table's row data, and which row in that segment.
type rowPos struct {
	subTableIndex   int
	rowSegmentIndex int
	rowIndex        uint64
}

// rowIter iterates over the rows of a hashTableV2 in [currentPos, endPos).
type rowIter struct {
	table      *hashTableV2
	currentPos *rowPos
	endPos     *rowPos
}

func (ri *rowIter) getValue() unsafe.Pointer {
	return ri.table.tables[ri.currentPos.subTableIndex].rowData.segments[ri.currentPos.rowSegmentIndex].getRowPointer(int(ri.currentPos.rowIndex))
}

func (ri *rowIter) next() {
	ri.currentPos.rowIndex++
	if ri.currentPos.rowIndex == uint64(ri.table.tables[ri.currentPos.subTableIndex].rowData.segments[ri.currentPos.rowSegmentIndex].rowCount()) {
		ri.currentPos.rowSegmentIndex++
		ri.currentPos.rowIndex = 0
		// move to the next sub-table once the current one is exhausted,
		// skipping any sub-table that contains no segments at all
		for ri.currentPos.rowSegmentIndex == len(ri.table.tables[ri.currentPos.subTableIndex].rowData.segments) {
			ri.currentPos.subTableIndex++
			ri.currentPos.rowSegmentIndex = 0
			if ri.currentPos.subTableIndex == int(ri.table.partitionNumber) {
				break
			}
		}
	}
}

// isEnd reports whether the iterator has reached endPos. The component-wise
// comparison is enough here because next() only moves forward one row at a
// time, so currentPos can never jump past endPos: once no component is
// smaller, the two positions are equal.
func (ri *rowIter) isEnd() bool {
	return !(ri.currentPos.subTableIndex < ri.endPos.subTableIndex || ri.currentPos.rowSegmentIndex < ri.endPos.rowSegmentIndex || ri.currentPos.rowIndex < ri.endPos.rowIndex)
}

func newJoinHashTableForTest(partitionedRowTables []*rowTable) *hashTableV2 {
	// the caller must first make sure there is no nil rowTable
	jht := &hashTableV2{
		tables:          make([]*subTable, len(partitionedRowTables)),
		partitionNumber: uint64(len(partitionedRowTables)),
	}
	for i, rowTable := range partitionedRowTables {
		jht.tables[i] = newSubTable(rowTable)
	}
	return jht
}

// createRowPos maps a flat row ordinal in [0, totalRowCount] to a rowPos.
// Passing totalRowCount yields the one-past-the-end sentinel position.
func (ht *hashTableV2) createRowPos(pos uint64) *rowPos {
	if pos > ht.totalRowCount() {
		panic("invalid call to createRowPos, the input pos should be in [0, totalRowCount]")
	}
	if pos == ht.totalRowCount() {
		return &rowPos{
			subTableIndex:   len(ht.tables),
			rowSegmentIndex: 0,
			rowIndex:        0,
		}
	}
	subTableIndex := 0
	for pos >= ht.tables[subTableIndex].rowData.rowCount() {
		pos -= ht.tables[subTableIndex].rowData.rowCount()
		subTableIndex++
	}
	rowSegmentIndex := 0
	for pos >= uint64(ht.tables[subTableIndex].rowData.segments[rowSegmentIndex].rowCount()) {
		pos -= uint64(ht.tables[subTableIndex].rowData.segments[rowSegmentIndex].rowCount())
		rowSegmentIndex++
	}
	return &rowPos{
		subTableIndex:   subTableIndex,
		rowSegmentIndex: rowSegmentIndex,
		rowIndex:        pos,
	}
}
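
// A worked example, assuming two sub-tables holding 10 and 20 rows, where the
// first sub-table has segments of 4 and 6 rows: createRowPos(7) skips no
// sub-table (7 < 10), skips the 4-row segment (7 >= 4), and returns
// {subTableIndex: 0, rowSegmentIndex: 1, rowIndex: 3}. createRowPos(12) skips
// the first sub-table (12 >= 10) and, assuming the second sub-table's first
// segment holds at least 3 rows, returns {1, 0, 2}.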

func (ht *hashTableV2) createRowIter(start, end uint64) *rowIter {
	if start > end {
		start = end
	}
	return &rowIter{
		table:      ht,
		currentPos: ht.createRowPos(start),
		endPos:     ht.createRowPos(end),
	}
}
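
// A typical use is splitting all rows evenly across workers. A hypothetical
// driver (not code from this file) might look like:
//
//	total := ht.totalRowCount()
//	workers := uint64(4) // hypothetical worker count
//	step := (total + workers - 1) / workers
//	for i := uint64(0); i < workers; i++ {
//		iter := ht.createRowIter(min(i*step, total), min((i+1)*step, total))
//		// hand iter to worker i, which scans rows via getValue()/next()/isEnd()
//	}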

func (ht *hashTableV2) isHashTableEmpty() bool {
	for _, subTable := range ht.tables {
		if !subTable.isHashTableEmpty {
			return false
		}
	}
	return true
}

func (ht *hashTableV2) totalRowCount() uint64 {
	ret := uint64(0)
	for _, table := range ht.tables {
		ret += table.rowData.rowCount()
	}
	return ret
}

func getHashTableLengthByRowTable(table *rowTable) uint64 {
	return getHashTableLengthByRowLen(table.validKeyCount())
}

func getHashTableLengthByRowLen(rowLen uint64) uint64 {
	return max(nextPowerOfTwo(rowLen), uint64(minimalHashTableLen))
}

func getHashTableMemoryUsage(hashTableLength uint64) int64 {
	return int64(hashTableLength) * taggedPointerLen
}
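
// For example, with 1000 valid keys, getHashTableLengthByRowLen(1000) rounds
// up to 1024 slots, so getHashTableMemoryUsage(1024) reports
// 1024 * taggedPointerLen bytes, i.e. 8 KiB if a taggedPtr is 8 bytes wide.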