diff --git a/pkg/server/server.go b/pkg/server/server.go index 5aa823e30d..7a967010a3 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -990,6 +990,9 @@ func (s *Server) Kill(connectionID uint64, query bool, maxExecutionTime bool, ru if err := conn.bufReadConn.SetWriteDeadline(time.Now()); err != nil { logutil.BgLogger().Warn("error setting write deadline for kill.", zap.Error(err)) } + if err := conn.bufReadConn.SetReadDeadline(time.Now()); err != nil { + logutil.BgLogger().Warn("error setting read deadline for kill.", zap.Error(err)) + } } } killQuery(conn, maxExecutionTime, runaway) @@ -1021,11 +1024,6 @@ func killQuery(conn *clientConn, maxExecutionTime, runaway bool) { if cancelFunc != nil { cancelFunc() } - if conn.bufReadConn != nil { - if err := conn.bufReadConn.SetReadDeadline(time.Now()); err != nil { - logutil.BgLogger().Warn("error setting read deadline for kill.", zap.Error(err)) - } - } sessVars.SQLKiller.FinishResultSet() } @@ -1052,6 +1050,11 @@ func (s *Server) KillAllConnections() { if err := conn.closeWithoutLock(); err != nil { terror.Log(err) } + if conn.bufReadConn != nil { + if err := conn.bufReadConn.SetReadDeadline(time.Now()); err != nil { + logutil.BgLogger().Warn("error setting read deadline for kill.", zap.Error(err)) + } + } killQuery(conn, false, false) } diff --git a/tests/globalkilltest/BUILD.bazel b/tests/globalkilltest/BUILD.bazel index b8752021bd..5ce5158d24 100644 --- a/tests/globalkilltest/BUILD.bazel +++ b/tests/globalkilltest/BUILD.bazel @@ -9,7 +9,7 @@ go_test( ], embed = [":globalkilltest"], flaky = True, - shard_count = 10, + shard_count = 11, deps = [ "//pkg/testkit/testsetup", "//pkg/util/logutil", diff --git a/tests/globalkilltest/global_kill_test.go b/tests/globalkilltest/global_kill_test.go index becfd10dc7..aa84f1ed69 100644 --- a/tests/globalkilltest/global_kill_test.go +++ b/tests/globalkilltest/global_kill_test.go @@ -840,3 +840,55 @@ func TestConnIDUpgradeAndDowngrade(t *testing.T) { conn.mustBe32(t) conn.Close() } + +func TestKillQueryOnIdleConnection(t *testing.T) { + s := createGlobalKillSuite(t, true) + require.NoErrorf(t, s.pdErr, msgErrConnectPD, s.pdErr) + + // tidb1 & conn1a,conn1b + port1 := *tidbStartPort + 1 + tidb1, err := s.startTiDBWithPD(port1, *tidbStatusPort+1, *pdClientPath) + require.NoError(t, err) + defer s.stopService("tidb1", tidb1, true) + + db1, err := s.connectTiDB(port1) + require.NoError(t, err) + defer db1.Close() + + db2, err := s.connectTiDB(port1) + require.NoError(t, err) + defer db2.Close() + + ctx := context.TODO() + conn1, err := db1.Conn(ctx) + require.NoError(t, err) + defer conn1.Close() + + var connID1 uint64 + err = conn1.QueryRowContext(ctx, "SELECT CONNECTION_ID();").Scan(&connID1) + require.NoError(t, err) + + conn2, err := db2.Conn(ctx) + require.NoError(t, err) + defer conn2.Close() + + rows, err := conn1.QueryContext(ctx, "select 1") + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Err()) + require.NoError(t, rows.Close()) + _, err = conn2.ExecContext(ctx, fmt.Sprintf("KILL QUERY %v", connID1)) + require.NoError(t, err) + // verify connection is still alive + rows, err = conn1.QueryContext(ctx, "select 1") + require.NoError(t, err) + require.True(t, rows.Next()) + require.NoError(t, rows.Err()) + require.NoError(t, rows.Close()) + + _, err = conn2.ExecContext(ctx, fmt.Sprintf("KILL CONNECTION %v", connID1)) + require.NoError(t, err) + // verify connection is closed + _, err = conn1.ExecContext(ctx, "select 1") + require.Error(t, err) +}