diff --git a/executor/distsql.go b/executor/distsql.go index b20ac8f8ee..c35fcb073b 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -44,9 +44,13 @@ const ( func resultRowToRow(t table.Table, h int64, data []types.Datum, tableAsName *model.CIStr) *Row { entry := &RowKeyEntry{ - Handle: h, - Tbl: t, - TableAsName: tableAsName, + Handle: h, + Tbl: t, + } + if tableAsName != nil && tableAsName.L != "" { + entry.TableName = tableAsName.L + } else { + entry.TableName = t.Meta().Name.L } return &Row{Data: data, RowKeys: []*RowKeyEntry{entry}} } diff --git a/executor/executor.go b/executor/executor.go index fe7ea3c58c..7116949d4b 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -100,7 +100,7 @@ type RowKeyEntry struct { // Row key. Handle int64 // Table alias name. - TableAsName *model.CIStr + TableName string } // Executor executes a query. @@ -691,9 +691,13 @@ func (e *TableScanExec) getRow(handle int64) (*Row, error) { // Put rowKey to the tail of record row. rke := &RowKeyEntry{ - Tbl: e.t, - Handle: handle, - TableAsName: e.asName, + Tbl: e.t, + Handle: handle, + } + if e.asName != nil && e.asName.L != "" { + rke.TableName = e.asName.L + } else { + rke.TableName = e.t.Meta().Name.L } row.RowKeys = append(row.RowKeys, rke) return row, nil diff --git a/executor/join.go b/executor/join.go index 772e9c63ee..4139b076ad 100644 --- a/executor/join.go +++ b/executor/join.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/mvmap" "github.com/pingcap/tidb/util/types" ) @@ -34,7 +35,7 @@ var ( // HashJoinExec implements the hash join algorithm. type HashJoinExec struct { - hashTable map[string][]*Row + hashTable *mvmap.MVMap smallHashKey []*expression.Column bigHashKey []*expression.Column smallExec Executor @@ -66,6 +67,10 @@ type HashJoinExec struct { // Channels for output. 
resultCh chan *execResult + + // rowKeyCache is used to store the table and table name from a row. + // Because every row has the same table name and table, we can use a single row key cache. + rowKeyCache []*RowKeyEntry } // hashJoinCtx holds the variables needed to do a hash join in one of many concurrent goroutines. @@ -104,9 +109,9 @@ func makeJoinRow(a *Row, b *Row) *Row { return ret } -// getHashKey gets the hash key when given a row and hash columns. +// getJoinKey gets the hash key when given a row and hash columns. // It will return a boolean value representing if the hash key has null, a byte slice representing the result hash code. -func getHashKey(sc *variable.StatementContext, cols []*expression.Column, row *Row, targetTypes []*types.FieldType, +func getJoinKey(sc *variable.StatementContext, cols []*expression.Column, row *Row, targetTypes []*types.FieldType, vals []types.Datum, bytes []byte) (bool, []byte, error) { var err error for i, col := range cols { @@ -208,9 +213,10 @@ func (e *HashJoinExec) prepare() error { e.wg.Add(1) go e.fetchBigExec() - e.hashTable = make(map[string][]*Row) + e.hashTable = mvmap.NewMVMap() e.cursor = 0 sc := e.ctx.GetSessionVars().StmtCtx + var buffer []byte for { row, err := e.smallExec.Next() if err != nil { @@ -231,18 +237,19 @@ func (e *HashJoinExec) prepare() error { continue } } - hasNull, hashcode, err := getHashKey(sc, e.smallHashKey, row, e.targetTypes, e.hashJoinContexts[0].datumBuffer, nil) + hasNull, joinKey, err := getJoinKey(sc, e.smallHashKey, row, e.targetTypes, e.hashJoinContexts[0].datumBuffer, nil) if err != nil { return errors.Trace(err) } if hasNull { continue } - if rows, ok := e.hashTable[string(hashcode)]; !ok { - e.hashTable[string(hashcode)] = []*Row{row} - } else { - e.hashTable[string(hashcode)] = append(rows, row) + buffer = buffer[:0] + buffer, err = e.encodeRow(buffer, row) + if err != nil { + return errors.Trace(err) } + e.hashTable.Put(joinKey, buffer) } e.resultCh = make(chan *execResult, 
e.concurrency) @@ -257,6 +264,54 @@ func (e *HashJoinExec) prepare() error { return nil } +func (e *HashJoinExec) encodeRow(b []byte, row *Row) ([]byte, error) { + numRowKeys := int64(len(row.RowKeys)) + b = codec.EncodeVarint(b, numRowKeys) + for _, rowKey := range row.RowKeys { + b = codec.EncodeVarint(b, rowKey.Handle) + } + if numRowKeys > 0 && e.rowKeyCache == nil { + e.rowKeyCache = make([]*RowKeyEntry, len(row.RowKeys)) + for i := 0; i < len(row.RowKeys); i++ { + rk := new(RowKeyEntry) + rk.Tbl = row.RowKeys[i].Tbl + rk.TableName = row.RowKeys[i].TableName + e.rowKeyCache[i] = rk + } + } + b, err := codec.EncodeValue(b, row.Data...) + return b, errors.Trace(err) +} + +func (e *HashJoinExec) decodeRow(data []byte) (*Row, error) { + row := new(Row) + data, entryLen, err := codec.DecodeVarint(data) + if err != nil { + return nil, errors.Trace(err) + } + for i := 0; i < int(entryLen); i++ { + entry := new(RowKeyEntry) + data, entry.Handle, err = codec.DecodeVarint(data) + if err != nil { + return nil, errors.Trace(err) + } + entry.Tbl = e.rowKeyCache[i].Tbl + entry.TableName = e.rowKeyCache[i].TableName + row.RowKeys = append(row.RowKeys, entry) + } + values := make([]types.Datum, e.smallExec.Schema().Len()) + err = codec.SetRawValues(data, values) + if err != nil { + return nil, errors.Trace(err) + } + err = decodeRawValues(values, e.smallExec.Schema()) + if err != nil { + return nil, errors.Trace(err) + } + row.Data = values + return row, nil +} + func (e *HashJoinExec) waitJoinWorkersAndCloseResultChan() { e.wg.Wait() close(e.resultCh) @@ -341,7 +396,7 @@ func (e *HashJoinExec) joinOneBigRow(ctx *hashJoinCtx, bigRow *Row, result *exec // constructMatchedRows creates matching result rows from a row in the big table. 
func (e *HashJoinExec) constructMatchedRows(ctx *hashJoinCtx, bigRow *Row) (matchedRows []*Row, err error) { sc := e.ctx.GetSessionVars().StmtCtx - hasNull, hashcode, err := getHashKey(sc, e.bigHashKey, bigRow, e.targetTypes, ctx.datumBuffer, ctx.hashKeyBuffer[0:0:cap(ctx.hashKeyBuffer)]) + hasNull, joinKey, err := getJoinKey(sc, e.bigHashKey, bigRow, e.targetTypes, ctx.datumBuffer, ctx.hashKeyBuffer[0:0:cap(ctx.hashKeyBuffer)]) if err != nil { return nil, errors.Trace(err) } @@ -349,12 +404,17 @@ func (e *HashJoinExec) constructMatchedRows(ctx *hashJoinCtx, bigRow *Row) (matc if hasNull { return } - rows, ok := e.hashTable[string(hashcode)] - if !ok { + values := e.hashTable.Get(joinKey) + if len(values) == 0 { return } // match eq condition - for _, smallRow := range rows { + for _, value := range values { + var smallRow *Row + smallRow, err = e.decodeRow(value) + if err != nil { + return nil, errors.Trace(err) + } otherMatched := true var matchedRow *Row if e.leftSmall { @@ -372,7 +432,6 @@ func (e *HashJoinExec) constructMatchedRows(ctx *hashJoinCtx, bigRow *Row) (matc matchedRows = append(matchedRows, matchedRow) } } - return matchedRows, nil } @@ -662,7 +721,7 @@ func (e *HashSemiJoinExec) prepare() error { continue } } - hasNull, hashcode, err := getHashKey(sc, e.smallHashKey, row, e.targetTypes, make([]types.Datum, len(e.smallHashKey)), nil) + hasNull, hashcode, err := getJoinKey(sc, e.smallHashKey, row, e.targetTypes, make([]types.Datum, len(e.smallHashKey)), nil) if err != nil { return errors.Trace(err) } @@ -683,7 +742,7 @@ func (e *HashSemiJoinExec) prepare() error { func (e *HashSemiJoinExec) rowIsMatched(bigRow *Row) (matched bool, hasNull bool, err error) { sc := e.ctx.GetSessionVars().StmtCtx - hasNull, hashcode, err := getHashKey(sc, e.bigHashKey, bigRow, e.targetTypes, make([]types.Datum, len(e.smallHashKey)), nil) + hasNull, hashcode, err := getJoinKey(sc, e.bigHashKey, bigRow, e.targetTypes, make([]types.Datum, len(e.smallHashKey)), nil) if err 
!= nil { return false, false, errors.Trace(err) } diff --git a/executor/union_scan.go b/executor/union_scan.go index f50f41b1ff..88f0dcdc1b 100644 --- a/executor/union_scan.go +++ b/executor/union_scan.go @@ -258,7 +258,13 @@ func (us *UnionScanExec) buildAndSortAddedRows(t table.Table, asName *model.CISt continue } } - rowKeyEntry := &RowKeyEntry{Handle: h, Tbl: t, TableAsName: asName} + rowKeyEntry := &RowKeyEntry{Handle: h, Tbl: t} + if asName != nil && asName.L != "" { + rowKeyEntry.TableName = asName.L + } else { + rowKeyEntry.TableName = t.Meta().Name.L + } + row := &Row{Data: newData, RowKeys: []*RowKeyEntry{rowKeyEntry}} us.addedRows = append(us.addedRows, row) } diff --git a/executor/write.go b/executor/write.go index 5861d25982..2bd9d9c653 100644 --- a/executor/write.go +++ b/executor/write.go @@ -214,20 +214,12 @@ func (e *DeleteExec) deleteMultiTables() error { } func isMatchTableName(entry *RowKeyEntry, tblMap map[int64][]string) bool { - var name string - if entry.TableAsName != nil { - name = entry.TableAsName.L - } - if len(name) == 0 { - name = entry.Tbl.Meta().Name.L - } - names, ok := tblMap[entry.Tbl.Meta().ID] if !ok { return false } for _, n := range names { - if n == name { + if n == entry.TableName { return true } } @@ -1202,16 +1194,9 @@ func (e *UpdateExec) fetchRows() error { } func getTableOffset(schema *expression.Schema, entry *RowKeyEntry) int { - t := entry.Tbl - var tblName string - if entry.TableAsName == nil || len(entry.TableAsName.L) == 0 { - tblName = t.Meta().Name.L - } else { - tblName = entry.TableAsName.L - } for i := 0; i < schema.Len(); i++ { s := schema.Columns[i] - if s.TblName.L == tblName { + if s.TblName.L == entry.TableName { return i } } diff --git a/util/mvmap/mvmap.go b/util/mvmap/mvmap.go new file mode 100644 index 0000000000..b1c1c5f461 --- /dev/null +++ b/util/mvmap/mvmap.go @@ -0,0 +1,167 @@ +// Copyright 2017 PingCAP, Inc. 
// entry describes one key/value pair held in the dataStore and links to the
// previously inserted entry whose key produced the same hash.
type entry struct {
	addr   dataAddr  // where key||value starts in the dataStore
	keyLen uint32    // number of key bytes at addr
	valLen uint32    // number of value bytes following the key
	next   entryAddr // older entry in the same hash chain, or nullEntryAddr
}

// entryStore is an append-only arena of entries. Entries are kept by value in
// a small number of large slices instead of being allocated one by one.
type entryStore struct {
	slices   [][]entry
	sliceIdx uint32 // index of the slice currently being appended to
	sliceLen uint32 // number of entries already in that slice
}

// dataStore is an append-only arena holding the raw key and value bytes.
type dataStore struct {
	slices   [][]byte
	sliceIdx uint32 // index of the slice currently being appended to
	sliceLen uint32 // number of bytes already in that slice
}

// entryAddr locates an entry inside an entryStore.
type entryAddr struct {
	sliceIdx uint32
	offset   uint32
}

// dataAddr locates the first key byte of an entry inside a dataStore.
type dataAddr struct {
	sliceIdx uint32
	offset   uint32
}

const (
	maxDataSliceLen  = 64 * 1024
	maxEntrySliceLen = 8 * 1024
)

// Parameters of the 64-bit FNV-1 hash; they yield the same values as the
// standard library's fnv.New64().
const (
	fnvOffset64 uint64 = 14695981039346656037
	fnvPrime64  uint64 = 1099511628211
)

// put appends key followed by value to the store and returns the address of
// the first key byte.
func (ds *dataStore) put(key, value []byte) dataAddr {
	dataLen := uint32(len(key) + len(value))
	// Start a new slice when the current one would grow past the limit. An
	// empty current slice is always reused, so an item larger than the limit
	// never leaves an unused slice behind.
	if ds.sliceLen != 0 && ds.sliceLen+dataLen > maxDataSliceLen {
		ds.slices = append(ds.slices, make([]byte, 0, max(maxDataSliceLen, int(dataLen))))
		ds.sliceLen = 0
		ds.sliceIdx++
	}
	addr := dataAddr{sliceIdx: ds.sliceIdx, offset: ds.sliceLen}
	slice := ds.slices[ds.sliceIdx]
	slice = append(slice, key...)
	slice = append(slice, value...)
	ds.slices[ds.sliceIdx] = slice
	ds.sliceLen += dataLen
	return addr
}

// max returns the larger of a and b.
func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

// get returns the value stored for e when e's stored key equals key, and nil
// otherwise. A matching entry with an empty value yields an empty, non-nil
// slice, which is how callers tell it apart from a mismatch.
func (ds *dataStore) get(e entry, key []byte) []byte {
	slice := ds.slices[e.addr.sliceIdx]
	valOffset := e.addr.offset + e.keyLen
	if !bytes.Equal(key, slice[e.addr.offset:valOffset]) {
		return nil
	}
	valEnd := valOffset + e.valLen
	// Cap the capacity at the value's end: the bytes that follow belong to
	// other entries and must not be reachable through append.
	return slice[valOffset:valEnd:valEnd]
}

// nullEntryAddr is the zero address. NewMVMap stores a dummy entry there, so
// no real entry ever has this address and it can terminate a chain.
var nullEntryAddr = entryAddr{}

// put appends e to the store and returns its address.
func (es *entryStore) put(e entry) entryAddr {
	if es.sliceLen == maxEntrySliceLen {
		es.slices = append(es.slices, make([]entry, 0, maxEntrySliceLen))
		es.sliceLen = 0
		es.sliceIdx++
	}
	addr := entryAddr{sliceIdx: es.sliceIdx, offset: es.sliceLen}
	slice := es.slices[es.sliceIdx]
	slice = append(slice, e)
	es.slices[es.sliceIdx] = slice
	es.sliceLen++
	return addr
}

// get returns the entry stored at addr.
func (es *entryStore) get(addr entryAddr) entry {
	return es.slices[addr.sliceIdx][addr.offset]
}

// MVMap stores multiple values for a given key with minimum GC overhead.
//
// Put mutates the map and must not run concurrently with any other method.
// Get only reads the map and keeps no per-call state in it, so once all Puts
// have finished it may be called from several goroutines at once.
type MVMap struct {
	entryStore entryStore
	dataStore  dataStore
	hashTable  map[uint64]entryAddr // key hash -> most recently added entry
}

// NewMVMap creates a new multi-value map.
func NewMVMap() *MVMap {
	m := new(MVMap)
	m.hashTable = make(map[uint64]entryAddr)
	m.entryStore.slices = [][]entry{make([]entry, 0, 64)}
	// Append a first empty entry so the zero entry address can represent null.
	m.entryStore.put(entry{})
	m.dataStore.slices = [][]byte{make([]byte, 0, 1024)}
	return m
}

// Put adds a key/value pair to the MVMap. If the key already exists the old
// values are not overwritten; the new value is placed in front of them.
// The key and value bytes are copied, so the caller may reuse both buffers.
func (m *MVMap) Put(key, value []byte) {
	hashKey := m.hash(key)
	oldEntryAddr := m.hashTable[hashKey]
	dataAddr := m.dataStore.put(key, value)
	e := entry{
		addr:   dataAddr,
		keyLen: uint32(len(key)),
		valLen: uint32(len(value)),
		next:   oldEntryAddr,
	}
	newEntryPtr := m.entryStore.put(e)
	m.hashTable[hashKey] = newEntryPtr
}

// Get returns the values stored for key, most recently added first, or nil
// when the key is absent. The returned slices point into the map's internal
// storage and must be treated as read-only.
func (m *MVMap) Get(key []byte) [][]byte {
	var values [][]byte
	hashKey := m.hash(key)
	entryAddr := m.hashTable[hashKey]
	for entryAddr != nullEntryAddr {
		e := m.entryStore.get(entryAddr)
		entryAddr = e.next
		// A chain can hold entries for other keys with the same hash; get
		// returns nil for those.
		val := m.dataStore.get(e, key)
		if val == nil {
			continue
		}
		values = append(values, val)
	}
	return values
}

// hash returns the 64-bit FNV-1 hash of key. It is computed without any
// state shared through the map, so concurrent Get calls (e.g. from parallel
// join workers) cannot disturb each other's hash computation.
func (m *MVMap) hash(key []byte) uint64 {
	h := fnvOffset64
	for _, c := range key {
		h *= fnvPrime64
		h ^= uint64(c)
	}
	return h
}
+ +package mvmap + +import ( + "bytes" + "encoding/binary" + "fmt" + "testing" +) + +func TestMVMap(t *testing.T) { + m := NewMVMap() + m.Put([]byte("abc"), []byte("abc1")) + m.Put([]byte("abc"), []byte("abc2")) + m.Put([]byte("def"), []byte("def1")) + m.Put([]byte("def"), []byte("def2")) + vals := m.Get([]byte("abc")) + if fmt.Sprintf("%s", vals) != "[abc2 abc1]" { + t.FailNow() + } + vals = m.Get([]byte("def")) + if fmt.Sprintf("%s", vals) != "[def2 def1]" { + t.FailNow() + } +} + +func BenchmarkMVMapPut(b *testing.B) { + m := NewMVMap() + buffer := make([]byte, 8) + for i := 0; i < b.N; i++ { + binary.BigEndian.PutUint64(buffer, uint64(i)) + m.Put(buffer, buffer) + } +} + +func BenchmarkMVMapGet(b *testing.B) { + m := NewMVMap() + buffer := make([]byte, 8) + for i := 0; i < b.N; i++ { + binary.BigEndian.PutUint64(buffer, uint64(i)) + m.Put(buffer, buffer) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + binary.BigEndian.PutUint64(buffer, uint64(i)) + val := m.Get(buffer) + if len(val) != 1 || bytes.Compare(val[0], buffer) != 0 { + b.FailNow() + } + } +}