diff --git a/store/tikv/isolation_test.go b/store/tikv/isolation_test.go new file mode 100644 index 0000000000..d2d63fc794 --- /dev/null +++ b/store/tikv/isolation_test.go @@ -0,0 +1,184 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package tikv + +import ( + "fmt" + "sort" + "sync" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/kv" +) + +type testIsolationSuite struct { + store *tikvStore +} + +var _ = Suite(&testIsolationSuite{}) + +func (s *testIsolationSuite) SetUpSuite(c *C) { + s.store = newTestStore(c) +} + +type writeRecord struct { + startTS uint64 + commitTS uint64 +} + +type writeRecords []writeRecord + +func (r writeRecords) Len() int { return len(r) } +func (r writeRecords) Swap(i, j int) { r[i], r[j] = r[j], r[i] } +func (r writeRecords) Less(i, j int) bool { return r[i].startTS <= r[j].startTS } + +func (s *testIsolationSuite) SetWithRetry(c *C, k, v []byte) writeRecord { + for { + txn, err := s.store.Begin() + c.Assert(err, IsNil) + + err = txn.Set(k, v) + c.Assert(err, IsNil) + + err = txn.Commit() + if err == nil { + return writeRecord{ + startTS: txn.StartTS(), + commitTS: txn.(*tikvTxn).commitTS, + } + } + c.Assert(kv.IsRetryableError(err), IsTrue) + } +} + +type readRecord struct { + startTS uint64 + value []byte +} + +type readRecords []readRecord + +func (r readRecords) Len() int { return len(r) } +func (r readRecords) Swap(i, j int) { r[i], r[j] = r[j], r[i] } +func (r readRecords) Less(i, j int) bool { return r[i].startTS <= r[j].startTS } + +func (s *testIsolationSuite) GetWithRetry(c *C, k []byte) readRecord { + for { + txn, err := s.store.Begin() + c.Assert(err, IsNil) + + val, err := txn.Get(k) + if err == nil { + return readRecord{ + startTS: txn.StartTS(), + value: val, + } + } + c.Assert(kv.IsRetryableError(err), IsTrue) + } +} + +func (s *testIsolationSuite) TestWriteWriteConflict(c *C) { + const ( + threadCount = 10 + setPerThread = 100 + ) + var ( + mu sync.Mutex + writes []writeRecord + wg sync.WaitGroup + ) + wg.Add(threadCount) + for i := 0; i < threadCount; i++ { + go func() { + defer wg.Done() + for j := 0; j < setPerThread; j++ { + w := s.SetWithRetry(c, []byte("k"), []byte("v")) + mu.Lock() + writes = append(writes, w) + mu.Unlock() + } + }() + } + wg.Wait() + + // Check all transactions' [startTS, commitTS] are not overlapped. + sort.Sort(writeRecords(writes)) + for i := 0; i < len(writes)-1; i++ { + c.Assert(writes[i].commitTS, Less, writes[i+1].startTS) + } +} + +func (s *testIsolationSuite) TestReadWriteConflict(c *C) { + const ( + readThreadCount = 10 + writeCount = 10 + ) + + var ( + writes []writeRecord + mu sync.Mutex + reads []readRecord + wg sync.WaitGroup + ) + + s.SetWithRetry(c, []byte("k"), []byte("0")) + + writeDone := make(chan struct{}) + go func() { + for i := 1; i <= writeCount; i++ { + w := s.SetWithRetry(c, []byte("k"), []byte(fmt.Sprintf("%d", i))) + writes = append(writes, w) + time.Sleep(time.Microsecond * 10) + } + close(writeDone) + }() + + wg.Add(readThreadCount) + for i := 0; i < readThreadCount; i++ { + go func() { + defer wg.Done() + for { + select { + case <-writeDone: + return + default: + } + r := s.GetWithRetry(c, []byte("k")) + mu.Lock() + reads = append(reads, r) + mu.Unlock() + } + }() + } + wg.Wait() + + sort.Sort(readRecords(reads)) + + // Check all reads got the value committed before it's startTS. + var i, j int + for ; i < len(writes); i++ { + for ; j < len(reads); j++ { + w, r := writes[i], reads[j] + if r.startTS >= w.commitTS { + break + } + c.Assert(string(r.value), Equals, fmt.Sprintf("%d", i)) + } + } + for ; j < len(reads); j++ { + c.Assert(string(reads[j].value), Equals, fmt.Sprintf("%d", len(writes))) + } +} diff --git a/store/tikv/txn.go b/store/tikv/txn.go index e83a0ea0aa..cb6a6b0dc0 100644 --- a/store/tikv/txn.go +++ b/store/tikv/txn.go @@ -120,6 +120,7 @@ func (txn *tikvTxn) Commit() error { if err != nil { return errors.Trace(err) } + txn.commitTS = committer.commitTS log.Debugf("[kv] finish commit txn %d", txn.StartTS()) return nil }