259 lines
8.1 KiB
Go
259 lines
8.1 KiB
Go
// Copyright 2024 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package join
|
|
|
|
import (
|
|
"sync/atomic"
|
|
"unsafe"
|
|
)
|
|
|
|
// subTable is one partition of the join hash table: the partition's row
// data plus a chained-bucket array of taggedPtr used for probing.
type subTable struct {
	// rowData holds the row-format data this hash table indexes into.
	rowData *rowTable
	// the taggedPtr is used to save the row address, during hash join build stage
	// it will convert the chunk data into row format, each row there is an unsafe.Pointer
	// pointing the start address of the row. The unsafe.Pointer will be converted to
	// taggedPtr and saved in hashTable.
	// Generally speaking it is unsafe or even illegal in go to save unsafe.Pointer
	// into uintptr, and later convert uintptr back to unsafe.Pointer since after save
	// the value of unsafe.Pointer into uintptr, it has no pointer semantics, and may
	// become invalid after GC. But it is ok to do this in hash join so far because
	// 1. the check of heapObjectsCanMove makes sure that if the object is in heap, the address will not be changed after GC
	// 2. row address only points to a valid address in `rowTableSegment.rawData`. `rawData` is a slice in `rowTableSegment`, and it will be used by multiple goroutines,
	// and its size will be runtime expanded, this kind of slice will always be allocated in heap
	hashTable []taggedPtr
	// posMask is len(hashTable)-1; the length is always a power of two
	// (see newSubTable), so hashValue&posMask selects a bucket.
	posMask uint64
	// isRowTableEmpty is true when rowData contains no rows at all.
	isRowTableEmpty bool
	// isHashTableEmpty is true when rowData contains no rows with a valid
	// join key (validKeyCount() == 0), i.e. nothing was inserted.
	isHashTableEmpty bool
}
|
|
|
|
func (st *subTable) getTotalMemoryUsage() int64 {
|
|
return st.rowData.getTotalMemoryUsage() + getHashTableMemoryUsage(uint64(len(st.hashTable)))
|
|
}
|
|
|
|
func (st *subTable) lookup(hashValue uint64, tagHelper *tagPtrHelper) taggedPtr {
|
|
ret := st.hashTable[hashValue&st.posMask]
|
|
hashTagValue := tagHelper.getTaggedValue(hashValue)
|
|
if uint64(ret)&hashTagValue != hashTagValue {
|
|
// if tag value not match, the key will not be matched
|
|
return 0
|
|
}
|
|
return ret
|
|
}
|
|
|
|
// nextPowerOfTwo returns the smallest power of two that is strictly
// greater than value, with a minimum result of 2. It panics when value
// is 1<<63 or larger, because the result would overflow uint64.
func nextPowerOfTwo(value uint64) uint64 {
	if value >= uint64(1)<<63 {
		panic("input value is too large")
	}
	result := uint64(2)
	for result <= value {
		result <<= 1
	}
	return result
}
|
|
|
|
func newSubTable(table *rowTable) *subTable {
|
|
ret := &subTable{
|
|
rowData: table,
|
|
isHashTableEmpty: false,
|
|
isRowTableEmpty: false,
|
|
}
|
|
if table.rowCount() == 0 {
|
|
ret.isRowTableEmpty = true
|
|
}
|
|
if table.validKeyCount() == 0 {
|
|
ret.isHashTableEmpty = true
|
|
}
|
|
hashTableLength := max(nextPowerOfTwo(table.validKeyCount()), uint64(32))
|
|
ret.hashTable = make([]taggedPtr, hashTableLength)
|
|
ret.posMask = hashTableLength - 1
|
|
return ret
|
|
}
|
|
|
|
func (st *subTable) updateHashValue(hashValue uint64, rowAddress unsafe.Pointer, tagHelper *tagPtrHelper) {
|
|
pos := hashValue & st.posMask
|
|
prev := st.hashTable[pos]
|
|
tagValue := tagHelper.getTaggedValue(hashValue | uint64(prev))
|
|
taggedAddress := tagHelper.toTaggedPtr(tagValue, rowAddress)
|
|
st.hashTable[pos] = taggedAddress
|
|
setNextRowAddress(rowAddress, prev)
|
|
}
|
|
|
|
// atomicUpdateHashValue pushes rowAddress onto the head of the bucket
// chain for hashValue using a CAS retry loop, so multiple goroutines can
// insert into the same sub table concurrently.
func (st *subTable) atomicUpdateHashValue(hashValue uint64, rowAddress unsafe.Pointer, tagHelper *tagPtrHelper) {
	pos := hashValue & st.posMask
	for {
		// Atomically read the current bucket head. taggedPtr is
		// uintptr-sized, so it can be loaded and CASed through a *uintptr
		// view of the slot.
		prev := taggedPtr(atomic.LoadUintptr((*uintptr)(unsafe.Pointer(&st.hashTable[pos]))))
		// Fold the previous head's tag bits into the new tag so the bucket
		// head keeps a superset of the tags of every entry in its chain.
		tagValue := tagHelper.getTaggedValue(hashValue | uint64(prev))
		taggedAddress := tagHelper.toTaggedPtr(tagValue, rowAddress)
		// Publish the new head only if no other goroutine replaced it since
		// the load above; on failure, retry against the fresh head.
		if atomic.CompareAndSwapUintptr((*uintptr)(unsafe.Pointer(&st.hashTable[pos])), uintptr(prev), uintptr(taggedAddress)) {
			// NOTE(review): the next-pointer is linked after the CAS makes
			// the row reachable; this assumes no goroutine probes the table
			// while the build is still running — confirm with callers.
			setNextRowAddress(rowAddress, prev)
			break
		}
	}
}
|
|
|
|
func (st *subTable) build(startSegmentIndex int, endSegmentIndex int, tagHelper *tagPtrHelper) {
|
|
if startSegmentIndex == 0 && endSegmentIndex == len(st.rowData.segments) {
|
|
for i := startSegmentIndex; i < endSegmentIndex; i++ {
|
|
for _, index := range st.rowData.segments[i].validJoinKeyPos {
|
|
rowAddress := st.rowData.segments[i].getRowPointer(index)
|
|
hashValue := st.rowData.segments[i].hashValues[index]
|
|
st.updateHashValue(hashValue, rowAddress, tagHelper)
|
|
}
|
|
}
|
|
} else {
|
|
for i := startSegmentIndex; i < endSegmentIndex; i++ {
|
|
for _, index := range st.rowData.segments[i].validJoinKeyPos {
|
|
rowAddress := st.rowData.segments[i].getRowPointer(index)
|
|
hashValue := st.rowData.segments[i].hashValues[index]
|
|
st.atomicUpdateHashValue(hashValue, rowAddress, tagHelper)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// hashTableV2 is the partitioned join hash table: one subTable per
// partition, with partitionNumber recording the partition count.
type hashTableV2 struct {
	// tables holds one sub table per partition; an entry may be nil when
	// that partition has not been built (see getPartitionMemoryUsage).
	tables []*subTable
	// partitionNumber is the number of partitions, kept as uint64 for the
	// comparisons in row iteration.
	partitionNumber uint64
}
|
|
|
|
func (ht *hashTableV2) getPartitionMemoryUsage(partID int) int64 {
|
|
if ht.tables[partID] != nil {
|
|
return ht.tables[partID].getTotalMemoryUsage()
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (ht *hashTableV2) clearPartitionSegments(partID int) {
|
|
if ht.tables[partID] != nil {
|
|
ht.tables[partID].rowData.clearSegments()
|
|
ht.tables[partID].hashTable = nil
|
|
}
|
|
}
|
|
|
|
// rowPos addresses one row inside a hashTableV2: the sub table it lives
// in, the segment within that sub table's row data, and the row's index
// inside the segment.
type rowPos struct {
	// subTableIndex indexes into hashTableV2.tables; it equals
	// len(tables) for the one-past-the-end sentinel (see createRowPos).
	subTableIndex int
	// rowSegmentIndex indexes into the sub table's rowData.segments.
	rowSegmentIndex int
	// rowIndex is the row's position within that segment.
	rowIndex uint64
}
|
|
|
|
// rowIter walks the rows of a hashTableV2 over the half-open range
// [currentPos, endPos), advancing segment by segment and sub table by
// sub table.
type rowIter struct {
	// table is the hash table being iterated.
	table *hashTableV2
	// currentPos is the row the iterator points at; mutated by next().
	currentPos *rowPos
	// endPos is the exclusive end of the iteration range.
	endPos *rowPos
}
|
|
|
|
func (ri *rowIter) getValue() unsafe.Pointer {
|
|
return ri.table.tables[ri.currentPos.subTableIndex].rowData.segments[ri.currentPos.rowSegmentIndex].getRowPointer(int(ri.currentPos.rowIndex))
|
|
}
|
|
|
|
// next advances the iterator by one row. When the current segment is
// exhausted it moves to the start of the next segment, and when a sub
// table runs out of segments it moves on through subsequent sub tables
// (skipping any with zero segments) until one has rows or the partition
// count is reached.
func (ri *rowIter) next() {
	ri.currentPos.rowIndex++
	if ri.currentPos.rowIndex == uint64(ri.table.tables[ri.currentPos.subTableIndex].rowData.segments[ri.currentPos.rowSegmentIndex].rowCount()) {
		// Current segment exhausted: step to row 0 of the next segment.
		ri.currentPos.rowSegmentIndex++
		ri.currentPos.rowIndex = 0
		// Advance past sub tables whose segment list is exhausted (or empty).
		// The break below fires right after subTableIndex reaches
		// partitionNumber, so the loop condition never indexes tables out of
		// range.
		for ri.currentPos.rowSegmentIndex == len(ri.table.tables[ri.currentPos.subTableIndex].rowData.segments) {
			ri.currentPos.subTableIndex++
			ri.currentPos.rowSegmentIndex = 0
			if ri.currentPos.subTableIndex == int(ri.table.partitionNumber) {
				break
			}
		}
	}
}
|
|
|
|
func (ri *rowIter) isEnd() bool {
|
|
return !(ri.currentPos.subTableIndex < ri.endPos.subTableIndex || ri.currentPos.rowSegmentIndex < ri.endPos.rowSegmentIndex || ri.currentPos.rowIndex < ri.endPos.rowIndex)
|
|
}
|
|
|
|
func newJoinHashTableForTest(partitionedRowTables []*rowTable) *hashTableV2 {
|
|
// first make sure there is no nil rowTable
|
|
jht := &hashTableV2{
|
|
tables: make([]*subTable, len(partitionedRowTables)),
|
|
partitionNumber: uint64(len(partitionedRowTables)),
|
|
}
|
|
for i, rowTable := range partitionedRowTables {
|
|
jht.tables[i] = newSubTable(rowTable)
|
|
}
|
|
return jht
|
|
}
|
|
|
|
// createRowPos translates a global row ordinal pos — rows counted in
// order across all sub tables and their segments — into a rowPos. pos
// must lie in [0, totalRowCount]; pos == totalRowCount produces the
// one-past-the-end sentinel used by row iterators.
func (ht *hashTableV2) createRowPos(pos uint64) *rowPos {
	if pos > ht.totalRowCount() {
		panic("invalid call to createRowPos, the input pos should be in [0, totalRowCount]")
	}
	if pos == ht.totalRowCount() {
		// End sentinel: subTableIndex equals len(ht.tables), segment and
		// row reset to zero.
		return &rowPos{
			subTableIndex:   len(ht.tables),
			rowSegmentIndex: 0,
			rowIndex:        0,
		}
	}
	// Consume whole sub tables while pos lies beyond them, reducing pos to
	// an offset inside the located sub table.
	subTableIndex := 0
	for pos >= ht.tables[subTableIndex].rowData.rowCount() {
		pos -= ht.tables[subTableIndex].rowData.rowCount()
		subTableIndex++
	}
	// Then consume whole segments the same way, leaving pos as the row
	// index within the located segment.
	rowSegmentIndex := 0
	for pos >= uint64(ht.tables[subTableIndex].rowData.segments[rowSegmentIndex].rowCount()) {
		pos -= uint64(ht.tables[subTableIndex].rowData.segments[rowSegmentIndex].rowCount())
		rowSegmentIndex++
	}
	return &rowPos{
		subTableIndex:   subTableIndex,
		rowSegmentIndex: rowSegmentIndex,
		rowIndex:        pos,
	}
}
|
|
func (ht *hashTableV2) createRowIter(start, end uint64) *rowIter {
|
|
if start > end {
|
|
start = end
|
|
}
|
|
return &rowIter{
|
|
table: ht,
|
|
currentPos: ht.createRowPos(start),
|
|
endPos: ht.createRowPos(end),
|
|
}
|
|
}
|
|
|
|
func (ht *hashTableV2) isHashTableEmpty() bool {
|
|
for _, subTable := range ht.tables {
|
|
if !subTable.isHashTableEmpty {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (ht *hashTableV2) totalRowCount() uint64 {
|
|
ret := uint64(0)
|
|
for _, table := range ht.tables {
|
|
ret += table.rowData.rowCount()
|
|
}
|
|
return ret
|
|
}
|
|
|
|
func getHashTableLengthByRowTable(table *rowTable) uint64 {
|
|
return getHashTableLengthByRowLen(table.validKeyCount())
|
|
}
|
|
|
|
func getHashTableLengthByRowLen(rowLen uint64) uint64 {
|
|
return max(nextPowerOfTwo(rowLen), uint64(minimalHashTableLen))
|
|
}
|
|
|
|
func getHashTableMemoryUsage(hashTableLength uint64) int64 {
|
|
return int64(hashTableLength) * taggedPointerLen
|
|
}
|