diff --git a/inspectkv/inspectkv.go b/inspectkv/inspectkv.go index 676acdd64d..e2a51c0a95 100644 --- a/inspectkv/inspectkv.go +++ b/inspectkv/inspectkv.go @@ -209,24 +209,24 @@ func checkColsAndIndex(txn kv.Transaction, t table.Table, idx *column.IndexedCol startKey := t.RecordKey(0, nil) kvIndex := kv.NewKVIndex(t.IndexPrefix(), idx.Name.L, idx.ID, idx.Unique) - err := t.IterRecords(txn, string(startKey), cols, - func(h1 int64, vals1 []interface{}, cols []*column.Col) (bool, error) { - isExist, h2, err := kvIndex.Exist(txn, vals1, h1) - if terror.ErrorEqual(err, terror.ErrKeyExists) { - record1 := &RecordData{Handle: h1, Values: vals1} - record2 := &RecordData{Handle: h2, Values: vals1} - return false, errors.Errorf("index:%v != record:%v", record2, record1) - } - if err != nil { - return false, errors.Trace(err) - } - if !isExist { - record := &RecordData{Handle: h1, Values: vals1} - return false, errors.Errorf("index:%v != record:%v", nil, record) - } + filterFunc := func(h1 int64, vals1 []interface{}, cols []*column.Col) (bool, error) { + isExist, h2, err := kvIndex.Exist(txn, vals1, h1) + if terror.ErrorEqual(err, terror.ErrKeyExists) { + record1 := &RecordData{Handle: h1, Values: vals1} + record2 := &RecordData{Handle: h2, Values: vals1} + return false, errors.Errorf("index:%v != record:%v", record2, record1) + } + if err != nil { + return false, errors.Trace(err) + } + if !isExist { + record := &RecordData{Handle: h1, Values: vals1} + return false, errors.Errorf("index:%v != record:%v", nil, record) + } - return true, nil - }) + return true, nil + } + err := t.IterRecords(txn, string(startKey), cols, filterFunc) if err != nil { return errors.Trace(err) @@ -240,20 +240,20 @@ func scanTableData(retriever kv.Retriever, t table.Table, cols []*column.Col, st var records []*RecordData startKey := t.RecordKey(startHandle, nil) - err := t.IterRecords(retriever, string(startKey), cols, - func(h int64, d []interface{}, cols []*column.Col) (bool, error) { - if limit != 0 { - r := &RecordData{ - Handle: h, - Values: d, - } - records = append(records, r) - limit-- - return true, nil + filterFunc := func(h int64, d []interface{}, cols []*column.Col) (bool, error) { + if limit != 0 { + r := &RecordData{ + Handle: h, + Values: d, } + records = append(records, r) + limit-- + return true, nil + } - return false, nil - }) + return false, nil + } + err := t.IterRecords(retriever, string(startKey), cols, filterFunc) if err != nil { return nil, 0, errors.Trace(err) } @@ -300,6 +300,7 @@ func ScanSnapshotTableData(store kv.Storage, ver kv.Version, t table.Table, star func CompareTableData(txn kv.Transaction, t table.Table, data []*RecordData, exact bool) error { var err error var vals []interface{} + m := make(map[int64]struct{}, len(data)) for _, r := range data { vals, err = t.RowWithCols(txn, r.Handle, t.Cols()) @@ -312,6 +313,7 @@ func CompareTableData(txn kv.Transaction, t table.Table, data []*RecordData, exa } if !exact { + m[r.Handle] = struct{}{} continue } if !reflect.DeepEqual(r.Values, vals) { @@ -320,6 +322,7 @@ func CompareTableData(txn kv.Transaction, t table.Table, data []*RecordData, exa err = errors.Errorf("data:%v != record:%v", record1, record2) break } + m[r.Handle] = struct{}{} } if err != nil { return errors.Trace(err) @@ -327,26 +330,17 @@ func CompareTableData(txn kv.Transaction, t table.Table, data []*RecordData, exa startKey := t.RecordKey(0, nil) filterFunc := func(h int64, vals []interface{}, cols []*column.Col) (bool, error) { - for _, r := range data { - if !exact { - if r.Handle == h { - return true, nil - } - continue - } - if r.Handle == h && reflect.DeepEqual(r.Values, vals) { - return true, nil - } + if _, ok := m[h]; !ok { + record := &RecordData{Handle: h, Values: vals} + return false, errors.Errorf("data:%v != record:%v", nil, record) + } - record := &RecordData{Handle: h, Values: vals} - return false, errors.Errorf("data:%v != record:%v", nil, record) + + return true, nil } err = t.IterRecords(txn, string(startKey), t.Cols(), filterFunc) - if err != nil { - return errors.Trace(err) - } - return nil + return errors.Trace(err) } // GetTableRecordsCount returns the total number of table records from startHandle.