diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index 37a5b39876..83d3746ca3 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -105,6 +105,7 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { } var handleVals map[string][]byte + var indexKeys []kv.Key if e.idxInfo != nil { // `SELECT a, b FROM t WHERE (a, b) IN ((1, 2), (1, 2), (2, 1), (1, 2))` should not return duplicated rows dedup := make(map[hack.MutableString]struct{}) @@ -129,6 +130,7 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { return keys[i].Cmp(keys[j]) < 0 }) } + indexKeys = keys // Fetch all handles. handleVals, err = batchGetter.BatchGet(ctx, keys) @@ -180,7 +182,15 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { // Fetch all values. var values map[string][]byte if e.lock { - values, err = e.lockKeys(ctx, keys) + lockKeys := make([]kv.Key, len(keys), len(keys)+len(indexKeys)) + copy(lockKeys, keys) + for _, idxKey := range indexKeys { + // lock the non-exist index key, using len(val) in case BatchGet result contains some zero len entries + if val := handleVals[string(idxKey)]; len(val) == 0 { + lockKeys = append(lockKeys, idxKey) + } + } + values, err = e.lockKeys(ctx, lockKeys) if err != nil { return err } diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 3d875ec5f6..09d1ae4b1b 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -560,14 +560,15 @@ func tryWhereIn2BatchPointGet(ctx sessionctx.Context, selStmt *ast.SelectStmt) * // Try use handle if tbl.PKIsHandle { for _, col := range tbl.Columns { - if mysql.HasPriKeyFlag(col.Flag) { - tryHandle = col.Name.L == colName.Name.Name.L + if mysql.HasPriKeyFlag(col.Flag) && col.Name.L == colName.Name.Name.L { + tryHandle = true fieldType = &col.FieldType whereColNames = append(whereColNames, col.Name.L) break } } - } else { + } + if !tryHandle { // Downgrade to use unique index whereColNames = append(whereColNames, colName.Name.Name.L) } diff --git a/session/pessimistic_test.go b/session/pessimistic_test.go index 18f11b43f6..5eeab27aa3 100644 --- a/session/pessimistic_test.go +++ b/session/pessimistic_test.go @@ -1073,3 +1073,26 @@ func (s *testPessimisticSuite) TestNonAutoCommitWithPessimisticMode(c *C) { tk.MustQuery("select * from t1 where c2 = 1 for update").Check(testkit.Rows("1 1", "2 1", "3 1", "4 1")) tk.MustExec("commit") } + +func (s *testPessimisticSuite) TestBatchPointGetLockIndex(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk2.MustExec("use test") + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (c1 int primary key, c2 int, c3 int, unique key uk(c2))") + tk.MustExec("insert into t1 values (1, 1, 1)") + tk.MustExec("insert into t1 values (5, 5, 5)") + tk.MustExec("insert into t1 values (10, 10, 10)") + tk.MustExec("begin pessimistic") + // the handle does not exist and the index key should be locked as point get executor did + tk.MustQuery("select * from t1 where c2 in (2, 3) for update").Check(testkit.Rows()) + tk2.MustExec("set innodb_lock_wait_timeout = 1") + tk2.MustExec("begin pessimistic") + err := tk2.ExecToErr("insert into t1 values(2, 2, 2)") + c.Assert(tikv.ErrLockWaitTimeout.Equal(err), IsTrue) + err = tk2.ExecToErr("select * from t1 where c2 = 3 for update nowait") + c.Assert(tikv.ErrLockAcquireFailAndNoWaitSet.Equal(err), IsTrue) + tk.MustExec("rollback") + tk2.MustExec("rollback") +}