969 lines
26 KiB
Go
969 lines
26 KiB
Go
// Copyright 2013 The ql Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSES/QL-LICENSE file.
|
|
|
|
// Copyright 2015 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package tidb
|
|
|
|
import (
|
|
goctx "context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/juju/errors"
|
|
"github.com/ngaut/log"
|
|
"github.com/pingcap/tidb/ast"
|
|
"github.com/pingcap/tidb/context"
|
|
"github.com/pingcap/tidb/domain"
|
|
"github.com/pingcap/tidb/executor"
|
|
"github.com/pingcap/tidb/kv"
|
|
"github.com/pingcap/tidb/meta"
|
|
"github.com/pingcap/tidb/mysql"
|
|
"github.com/pingcap/tidb/parser"
|
|
"github.com/pingcap/tidb/perfschema"
|
|
"github.com/pingcap/tidb/privilege"
|
|
"github.com/pingcap/tidb/privilege/privileges"
|
|
"github.com/pingcap/tidb/sessionctx"
|
|
"github.com/pingcap/tidb/sessionctx/binloginfo"
|
|
"github.com/pingcap/tidb/sessionctx/variable"
|
|
"github.com/pingcap/tidb/sessionctx/varsutil"
|
|
"github.com/pingcap/tidb/store/localstore"
|
|
"github.com/pingcap/tidb/terror"
|
|
"github.com/pingcap/tidb/util/types"
|
|
"github.com/pingcap/tipb/go-binlog"
|
|
)
|
|
|
|
// Session context
|
|
type Session interface {
|
|
context.Context
|
|
Status() uint16 // Flag of current status, such as autocommit.
|
|
LastInsertID() uint64 // Last inserted auto_increment id.
|
|
AffectedRows() uint64 // Affected rows by latest executed stmt.
|
|
Execute(sql string) ([]ast.RecordSet, error) // Execute a sql statement.
|
|
String() string // For debug
|
|
CommitTxn() error
|
|
RollbackTxn() error
|
|
// For execute prepare statement in binary protocol.
|
|
PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error)
|
|
// Execute a prepared statement.
|
|
ExecutePreparedStmt(stmtID uint32, param ...interface{}) (ast.RecordSet, error)
|
|
DropPreparedStmt(stmtID uint32) error
|
|
SetClientCapability(uint32) // Set client capability flags.
|
|
SetConnectionID(uint64)
|
|
Close() error
|
|
Auth(user string, auth []byte, salt []byte) bool
|
|
// Cancel the execution of current transaction.
|
|
Cancel()
|
|
}
|
|
|
|
var (
|
|
_ Session = (*session)(nil)
|
|
sessionMu sync.Mutex
|
|
)
|
|
|
|
type stmtRecord struct {
|
|
stmtID uint32
|
|
st ast.Statement
|
|
params []interface{}
|
|
}
|
|
|
|
type stmtHistory struct {
|
|
history []*stmtRecord
|
|
}
|
|
|
|
func (h *stmtHistory) add(stmtID uint32, st ast.Statement, params ...interface{}) {
|
|
s := &stmtRecord{
|
|
stmtID: stmtID,
|
|
st: st,
|
|
params: append(([]interface{})(nil), params...),
|
|
}
|
|
h.history = append(h.history, s)
|
|
}
|
|
|
|
type session struct {
|
|
txn kv.Transaction // current transaction
|
|
txnCh chan *txnWithErr
|
|
// For cancel the execution of current transaction.
|
|
goCtx goctx.Context
|
|
cancelFunc goctx.CancelFunc
|
|
|
|
values map[fmt.Stringer]interface{}
|
|
store kv.Storage
|
|
|
|
// Used for test only.
|
|
unlimitedRetryCount bool
|
|
|
|
// For performance_schema only.
|
|
stmtState *perfschema.StatementState
|
|
parser *parser.Parser
|
|
|
|
sessionVars *variable.SessionVars
|
|
}
|
|
|
|
// Cancel cancels the execution of current transaction.
|
|
func (s *session) Cancel() {
|
|
// TODO: How to wait for the resource to release and make sure
|
|
// it's not leak?
|
|
s.cancelFunc()
|
|
}
|
|
|
|
// Canceled implements context.Context interface.
|
|
func (s *session) Done() <-chan struct{} {
|
|
return s.goCtx.Done()
|
|
}
|
|
|
|
func (s *session) cleanRetryInfo() {
|
|
if !s.sessionVars.RetryInfo.Retrying {
|
|
retryInfo := s.sessionVars.RetryInfo
|
|
for _, stmtID := range retryInfo.DroppedPreparedStmtIDs {
|
|
delete(s.sessionVars.PreparedStmts, stmtID)
|
|
}
|
|
retryInfo.Clean()
|
|
}
|
|
}
|
|
|
|
func (s *session) Status() uint16 {
|
|
return s.sessionVars.Status
|
|
}
|
|
|
|
func (s *session) LastInsertID() uint64 {
|
|
return s.sessionVars.LastInsertID
|
|
}
|
|
|
|
func (s *session) AffectedRows() uint64 {
|
|
return s.sessionVars.StmtCtx.AffectedRows()
|
|
}
|
|
|
|
func (s *session) SetClientCapability(capability uint32) {
|
|
s.sessionVars.ClientCapability = capability
|
|
}
|
|
|
|
func (s *session) SetConnectionID(connectionID uint64) {
|
|
s.sessionVars.ConnectionID = connectionID
|
|
}
|
|
|
|
type schemaLeaseChecker struct {
|
|
domain.SchemaValidator
|
|
schemaVer int64
|
|
}
|
|
|
|
const (
|
|
schemaOutOfDateRetryInterval = 500 * time.Millisecond
|
|
schemaOutOfDateRetryTimes = 10
|
|
)
|
|
|
|
func (s *schemaLeaseChecker) Check(txnTS uint64) error {
|
|
for i := 0; i < schemaOutOfDateRetryTimes; i++ {
|
|
err := s.checkOnce(txnTS)
|
|
switch err {
|
|
case nil:
|
|
return nil
|
|
case domain.ErrInfoSchemaChanged:
|
|
schemaLeaseErrorCounter.WithLabelValues("changed").Inc()
|
|
return errors.Trace(err)
|
|
default:
|
|
schemaLeaseErrorCounter.WithLabelValues("outdated").Inc()
|
|
time.Sleep(schemaOutOfDateRetryInterval)
|
|
}
|
|
}
|
|
return domain.ErrInfoSchemaExpired
|
|
}
|
|
|
|
func (s *schemaLeaseChecker) checkOnce(txnTS uint64) error {
|
|
succ := s.SchemaValidator.Check(txnTS, s.schemaVer)
|
|
if !succ {
|
|
if s.SchemaValidator.Latest() > s.schemaVer {
|
|
return domain.ErrInfoSchemaChanged
|
|
}
|
|
return domain.ErrInfoSchemaExpired
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *session) doCommit() error {
|
|
if s.txn == nil || !s.txn.Valid() {
|
|
return nil
|
|
}
|
|
defer func() {
|
|
s.txn = nil
|
|
s.sessionVars.SetStatusFlag(mysql.ServerStatusInTrans, false)
|
|
}()
|
|
if binloginfo.PumpClient != nil {
|
|
prewriteValue := binloginfo.GetPrewriteValue(s, false)
|
|
if prewriteValue != nil {
|
|
prewriteData, err := prewriteValue.Marshal()
|
|
if err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
bin := &binlog.Binlog{
|
|
Tp: binlog.BinlogType_Prewrite,
|
|
PrewriteValue: prewriteData,
|
|
}
|
|
s.txn.SetOption(kv.BinlogData, bin)
|
|
}
|
|
}
|
|
|
|
// Set this option for 2 phase commit to validate schema lease.
|
|
s.txn.SetOption(kv.SchemaLeaseChecker, &schemaLeaseChecker{
|
|
SchemaValidator: sessionctx.GetDomain(s).SchemaValidator,
|
|
schemaVer: s.sessionVars.TxnCtx.SchemaVersion,
|
|
})
|
|
if err := s.txn.Commit(); err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *session) doCommitWithRetry() error {
|
|
var txnSize int
|
|
if s.txn != nil && s.txn.Valid() {
|
|
txnSize = s.txn.Size()
|
|
}
|
|
err := s.doCommit()
|
|
if err != nil {
|
|
if s.isRetryableError(err) {
|
|
// Transactions will retry 2 ~ commitRetryLimit times.
|
|
// We make larger transactions retry less times to prevent cluster resource outage.
|
|
txnSizeRate := float64(txnSize) / float64(kv.TxnTotalSizeLimit)
|
|
maxRetryCount := commitRetryLimit - int(float64(commitRetryLimit-1)*txnSizeRate)
|
|
err = s.retry(maxRetryCount)
|
|
}
|
|
}
|
|
s.cleanRetryInfo()
|
|
if err != nil {
|
|
log.Warnf("[%d] finished txn:%v, %v", s.sessionVars.ConnectionID, s.txn, err)
|
|
return errors.Trace(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *session) CommitTxn() error {
|
|
return s.doCommitWithRetry()
|
|
}
|
|
|
|
func (s *session) RollbackTxn() error {
|
|
var err error
|
|
if s.txn != nil && s.txn.Valid() {
|
|
err = s.txn.Rollback()
|
|
}
|
|
s.cleanRetryInfo()
|
|
s.txn = nil
|
|
s.txnCh = nil
|
|
s.sessionVars.SetStatusFlag(mysql.ServerStatusInTrans, false)
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
func (s *session) GetClient() kv.Client {
|
|
return s.store.GetClient()
|
|
}
|
|
|
|
func (s *session) String() string {
|
|
// TODO: how to print binded context in values appropriately?
|
|
sessVars := s.sessionVars
|
|
data := map[string]interface{}{
|
|
"id": sessVars.ConnectionID,
|
|
"user": sessVars.User,
|
|
"currDBName": sessVars.CurrentDB,
|
|
"stauts": sessVars.Status,
|
|
"strictMode": sessVars.StrictSQLMode,
|
|
}
|
|
if s.txn != nil {
|
|
// if txn is committed or rolled back, txn is nil.
|
|
data["txn"] = s.txn.String()
|
|
}
|
|
if sessVars.SnapshotTS != 0 {
|
|
data["snapshotTS"] = sessVars.SnapshotTS
|
|
}
|
|
if sessVars.LastInsertID > 0 {
|
|
data["lastInsertID"] = sessVars.LastInsertID
|
|
}
|
|
if len(sessVars.PreparedStmts) > 0 {
|
|
data["preparedStmtCount"] = len(sessVars.PreparedStmts)
|
|
}
|
|
b, _ := json.MarshalIndent(data, "", " ")
|
|
return string(b)
|
|
}
|
|
|
|
const sqlLogMaxLen = 1024
|
|
|
|
func (s *session) isRetryableError(err error) bool {
|
|
return kv.IsRetryableError(err) || terror.ErrorEqual(err, domain.ErrInfoSchemaChanged)
|
|
}
|
|
|
|
func (s *session) retry(maxCnt int) error {
|
|
connID := s.sessionVars.ConnectionID
|
|
if s.sessionVars.TxnCtx.ForUpdate {
|
|
return errors.Errorf("[%d] can not retry select for update statement", connID)
|
|
}
|
|
s.sessionVars.RetryInfo.Retrying = true
|
|
retryCnt := 0
|
|
defer func() {
|
|
s.sessionVars.RetryInfo.Retrying = false
|
|
sessionRetry.Observe(float64(retryCnt))
|
|
}()
|
|
nh := getHistory(s)
|
|
var err error
|
|
for {
|
|
s.prepareTxnCtx()
|
|
s.sessionVars.RetryInfo.ResetOffset()
|
|
for _, sr := range nh.history {
|
|
st := sr.st
|
|
txt := st.OriginText()
|
|
log.Warnf("[%d] Retry %s", connID, sqlForLog(txt))
|
|
_, err = st.Exec(s)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
if err == nil {
|
|
err = s.doCommit()
|
|
if err == nil {
|
|
break
|
|
}
|
|
}
|
|
if !s.isRetryableError(err) {
|
|
log.Warnf("[%d] session:%v, err:%v", connID, s, err)
|
|
return errors.Trace(err)
|
|
}
|
|
retryCnt++
|
|
if !s.unlimitedRetryCount && (retryCnt >= maxCnt) {
|
|
log.Warnf("[%id] Retry reached max count %d", connID, retryCnt)
|
|
return errors.Trace(err)
|
|
}
|
|
kv.BackOff(retryCnt)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func sqlForLog(sql string) string {
|
|
if len(sql) > sqlLogMaxLen {
|
|
return sql[:sqlLogMaxLen] + fmt.Sprintf("(len:%d)", len(sql))
|
|
}
|
|
return sql
|
|
}
|
|
|
|
// sysSessionPool caches the system sessions used by ExecRestrictedSQL. It has
// no New func, so Get returns nil when the pool is empty.
var sysSessionPool sync.Pool
|
|
|
|
// ExecRestrictedSQL implements RestrictedSQLExecutor interface.
|
|
// This is used for executing some restricted sql statements, usually executed during a normal statement execution.
|
|
// Unlike normal Exec, it doesn't reset statement status, doesn't commit or rollback the current transaction
|
|
// and doesn't write binlog.
|
|
func (s *session) ExecRestrictedSQL(ctx context.Context, sql string) ([]*ast.Row, []*ast.ResultField, error) {
|
|
// Use special session to execute the sql.
|
|
var se *session
|
|
tmp := sysSessionPool.Get()
|
|
if tmp != nil {
|
|
se = tmp.(*session)
|
|
} else {
|
|
var err error
|
|
se, err = createSession(s.store)
|
|
if err != nil {
|
|
return nil, nil, errors.Trace(err)
|
|
}
|
|
varsutil.SetSessionSystemVar(se.sessionVars, variable.AutocommitVar, types.NewStringDatum("1"))
|
|
se.sessionVars.CommonGlobalLoaded = true
|
|
se.sessionVars.InRestrictedSQL = true
|
|
}
|
|
defer sysSessionPool.Put(se)
|
|
|
|
recordSets, err := se.Execute(sql)
|
|
if err != nil {
|
|
return nil, nil, errors.Trace(err)
|
|
}
|
|
|
|
var (
|
|
rows []*ast.Row
|
|
fields []*ast.ResultField
|
|
)
|
|
// Execute all recordset, take out the first one as result.
|
|
for i, rs := range recordSets {
|
|
tmp, err := drainRecordSet(rs)
|
|
if err != nil {
|
|
return nil, nil, errors.Trace(err)
|
|
}
|
|
if err = rs.Close(); err != nil {
|
|
return nil, nil, errors.Trace(err)
|
|
}
|
|
|
|
if i == 0 {
|
|
rows = tmp
|
|
fields, err = rs.Fields()
|
|
if err != nil {
|
|
return nil, nil, errors.Trace(err)
|
|
}
|
|
}
|
|
}
|
|
return rows, fields, nil
|
|
}
|
|
|
|
func drainRecordSet(rs ast.RecordSet) ([]*ast.Row, error) {
|
|
var rows []*ast.Row
|
|
for {
|
|
row, err := rs.Next()
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
if row == nil {
|
|
break
|
|
}
|
|
rows = append(rows, row)
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
// getExecRet executes restricted sql and the result is one column.
|
|
// It returns a string value.
|
|
func (s *session) getExecRet(ctx context.Context, sql string) (string, error) {
|
|
rows, _, err := s.ExecRestrictedSQL(ctx, sql)
|
|
if err != nil {
|
|
return "", errors.Trace(err)
|
|
}
|
|
if len(rows) == 0 {
|
|
return "", terror.ExecResultIsEmpty
|
|
}
|
|
value, err := types.ToString(rows[0].Data[0].GetValue())
|
|
if err != nil {
|
|
return "", errors.Trace(err)
|
|
}
|
|
return value, nil
|
|
}
|
|
|
|
// GetGlobalSysVar implements GlobalVarAccessor.GetGlobalSysVar interface.
|
|
func (s *session) GetGlobalSysVar(name string) (string, error) {
|
|
if s.Value(context.Initing) != nil {
|
|
// When running bootstrap or upgrade, we should not access global storage.
|
|
return "", nil
|
|
}
|
|
sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`,
|
|
mysql.SystemDB, mysql.GlobalVariablesTable, name)
|
|
sysVar, err := s.getExecRet(s, sql)
|
|
if err != nil {
|
|
if terror.ExecResultIsEmpty.Equal(err) {
|
|
return "", variable.UnknownSystemVar.GenByArgs(name)
|
|
}
|
|
return "", errors.Trace(err)
|
|
}
|
|
return sysVar, nil
|
|
}
|
|
|
|
// SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface.
|
|
func (s *session) SetGlobalSysVar(name string, value string) error {
|
|
sql := fmt.Sprintf(`UPDATE %s.%s SET VARIABLE_VALUE="%s" WHERE VARIABLE_NAME="%s";`,
|
|
mysql.SystemDB, mysql.GlobalVariablesTable, value, strings.ToLower(name))
|
|
_, _, err := s.ExecRestrictedSQL(s, sql)
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
func (s *session) ParseSQL(sql, charset, collation string) ([]ast.StmtNode, error) {
|
|
return s.parser.Parse(sql, charset, collation)
|
|
}
|
|
|
|
func (s *session) Execute(sql string) ([]ast.RecordSet, error) {
|
|
s.prepareTxnCtx()
|
|
startTS := time.Now()
|
|
charset, collation := s.sessionVars.GetCharsetInfo()
|
|
connID := s.sessionVars.ConnectionID
|
|
rawStmts, err := s.ParseSQL(sql, charset, collation)
|
|
if err != nil {
|
|
log.Warnf("[%d] parse error:\n%v\n%s", connID, err, sql)
|
|
return nil, errors.Trace(err)
|
|
}
|
|
sessionExecuteParseDuration.Observe(time.Since(startTS).Seconds())
|
|
|
|
var rs []ast.RecordSet
|
|
ph := sessionctx.GetDomain(s).PerfSchema()
|
|
for i, rst := range rawStmts {
|
|
s.prepareTxnCtx()
|
|
startTS := time.Now()
|
|
// Some execution is done in compile stage, so we reset it before compile.
|
|
resetStmtCtx(s, rst)
|
|
st, err1 := Compile(s, rst)
|
|
if err1 != nil {
|
|
log.Warnf("[%d] compile error:\n%v\n%s", connID, err1, sql)
|
|
s.RollbackTxn()
|
|
return nil, errors.Trace(err1)
|
|
}
|
|
sessionExecuteCompileDuration.Observe(time.Since(startTS).Seconds())
|
|
|
|
s.stmtState = ph.StartStatement(sql, connID, perfschema.CallerNameSessionExecute, rawStmts[i])
|
|
s.SetValue(context.QueryString, st.OriginText())
|
|
|
|
startTS = time.Now()
|
|
r, err := runStmt(s, st)
|
|
ph.EndStatement(s.stmtState)
|
|
if err != nil {
|
|
if !terror.ErrorEqual(err, kv.ErrKeyExists) {
|
|
log.Warnf("[%d] session error:\n%v\n%s", connID, errors.ErrorStack(err), s)
|
|
}
|
|
return nil, errors.Trace(err)
|
|
}
|
|
sessionExecuteRunDuration.Observe(time.Since(startTS).Seconds())
|
|
if r != nil {
|
|
rs = append(rs, r)
|
|
}
|
|
}
|
|
|
|
if s.sessionVars.ClientCapability&mysql.ClientMultiResults == 0 && len(rs) > 1 {
|
|
// return the first recordset if client doesn't support ClientMultiResults.
|
|
rs = rs[:1]
|
|
}
|
|
return rs, nil
|
|
}
|
|
|
|
// For execute prepare statement in binary protocol
|
|
func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error) {
|
|
if s.sessionVars.TxnCtx.InfoSchema == nil {
|
|
// We don't need to create a transaction for prepare statement, just get information schema will do.
|
|
s.sessionVars.TxnCtx.InfoSchema = sessionctx.GetDomain(s).InfoSchema()
|
|
}
|
|
prepareExec := &executor.PrepareExec{
|
|
IS: executor.GetInfoSchema(s),
|
|
Ctx: s,
|
|
SQLText: sql,
|
|
}
|
|
prepareExec.DoPrepare()
|
|
return prepareExec.ID, prepareExec.ParamCount, prepareExec.Fields, prepareExec.Err
|
|
}
|
|
|
|
// checkArgs makes sure all the arguments' types are known and can be handled.
|
|
// integer types are converted to int64 and uint64, time.Time is converted to types.Time.
|
|
// time.Duration is converted to types.Duration, other known types are leaved as it is.
|
|
func checkArgs(args ...interface{}) error {
|
|
for i, v := range args {
|
|
switch x := v.(type) {
|
|
case bool:
|
|
if x {
|
|
args[i] = int64(1)
|
|
} else {
|
|
args[i] = int64(0)
|
|
}
|
|
case int8:
|
|
args[i] = int64(x)
|
|
case int16:
|
|
args[i] = int64(x)
|
|
case int32:
|
|
args[i] = int64(x)
|
|
case int:
|
|
args[i] = int64(x)
|
|
case uint8:
|
|
args[i] = uint64(x)
|
|
case uint16:
|
|
args[i] = uint64(x)
|
|
case uint32:
|
|
args[i] = uint64(x)
|
|
case uint:
|
|
args[i] = uint64(x)
|
|
case int64:
|
|
case uint64:
|
|
case float32:
|
|
case float64:
|
|
case string:
|
|
case []byte:
|
|
case time.Duration:
|
|
args[i] = types.Duration{Duration: x}
|
|
case time.Time:
|
|
args[i] = types.Time{Time: types.FromGoTime(x), Type: mysql.TypeDatetime}
|
|
case nil:
|
|
default:
|
|
return errors.Errorf("cannot use arg[%d] (type %T):unsupported type", i, v)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ExecutePreparedStmt executes a prepared statement.
|
|
func (s *session) ExecutePreparedStmt(stmtID uint32, args ...interface{}) (ast.RecordSet, error) {
|
|
err := checkArgs(args...)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
s.prepareTxnCtx()
|
|
st := executor.CompileExecutePreparedStmt(s, stmtID, args...)
|
|
r, err := runStmt(s, st)
|
|
return r, errors.Trace(err)
|
|
}
|
|
|
|
func (s *session) DropPreparedStmt(stmtID uint32) error {
|
|
vars := s.sessionVars
|
|
if _, ok := vars.PreparedStmts[stmtID]; !ok {
|
|
return executor.ErrStmtNotFound
|
|
}
|
|
vars.RetryInfo.DroppedPreparedStmtIDs = append(vars.RetryInfo.DroppedPreparedStmtIDs, stmtID)
|
|
return nil
|
|
}
|
|
|
|
// If forceNew is true, GetTxn() must return a new transaction.
|
|
// In this situation, if current transaction is still in progress,
|
|
// there will be an implicit commit and create a new transaction.
|
|
func (s *session) GetTxn(forceNew bool) (kv.Transaction, error) {
|
|
if s.txn != nil && !forceNew {
|
|
return s.txn, nil
|
|
}
|
|
|
|
var err error
|
|
var force string
|
|
if s.txn == nil {
|
|
err = s.loadCommonGlobalVariablesIfNeeded()
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
} else if forceNew {
|
|
err = s.CommitTxn()
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
force = "force"
|
|
}
|
|
s.txn, err = s.store.Begin()
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
ac := s.sessionVars.IsAutocommit()
|
|
if !ac {
|
|
s.sessionVars.SetStatusFlag(mysql.ServerStatusInTrans, true)
|
|
}
|
|
log.Infof("[%d] %s new txn:%s", s.sessionVars.ConnectionID, force, s.txn)
|
|
return s.txn, nil
|
|
}
|
|
|
|
func (s *session) Txn() kv.Transaction {
|
|
return s.txn
|
|
}
|
|
|
|
func (s *session) NewTxn() error {
|
|
if s.txn != nil && s.txn.Valid() {
|
|
err := s.doCommitWithRetry()
|
|
if err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
}
|
|
txn, err := s.store.Begin()
|
|
if err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
s.txn = txn
|
|
return nil
|
|
}
|
|
|
|
func (s *session) SetValue(key fmt.Stringer, value interface{}) {
|
|
s.values[key] = value
|
|
}
|
|
|
|
func (s *session) Value(key fmt.Stringer) interface{} {
|
|
value := s.values[key]
|
|
return value
|
|
}
|
|
|
|
func (s *session) ClearValue(key fmt.Stringer) {
|
|
delete(s.values, key)
|
|
}
|
|
|
|
// Close function does some clean work when session end.
|
|
func (s *session) Close() error {
|
|
return s.RollbackTxn()
|
|
}
|
|
|
|
// GetSessionVars implements the context.Context interface.
|
|
func (s *session) GetSessionVars() *variable.SessionVars {
|
|
return s.sessionVars
|
|
}
|
|
|
|
func (s *session) getPassword(name, host string) (string, error) {
|
|
// Get password for name and host.
|
|
authSQL := fmt.Sprintf("SELECT Password FROM %s.%s WHERE User='%s' and Host='%s';", mysql.SystemDB, mysql.UserTable, name, host)
|
|
pwd, err := s.getExecRet(s, authSQL)
|
|
if err == nil {
|
|
return pwd, nil
|
|
} else if !terror.ExecResultIsEmpty.Equal(err) {
|
|
return "", errors.Trace(err)
|
|
}
|
|
//Try to get user password for name with any host(%).
|
|
authSQL = fmt.Sprintf("SELECT Password FROM %s.%s WHERE User='%s' and Host='%%';", mysql.SystemDB, mysql.UserTable, name)
|
|
pwd, err = s.getExecRet(s, authSQL)
|
|
return pwd, errors.Trace(err)
|
|
}
|
|
|
|
func (s *session) Auth(user string, auth []byte, salt []byte) bool {
|
|
strs := strings.Split(user, "@")
|
|
if len(strs) != 2 {
|
|
log.Warnf("Invalid format for user: %s", user)
|
|
return false
|
|
}
|
|
// Get user password.
|
|
name := strs[0]
|
|
host := strs[1]
|
|
|
|
checker := privilege.GetPrivilegeChecker(s)
|
|
if !checker.ConnectionVerification(name, host, auth, salt) {
|
|
log.Errorf("User connection verification failed %v", name)
|
|
return false
|
|
}
|
|
s.sessionVars.User = user
|
|
return true
|
|
}
|
|
|
|
// Some vars name for debug.
const retryEmptyHistoryList = "RetryEmptyHistoryList"

// chooseMinLease returns the shorter of two lease durations (n1 on a tie).
func chooseMinLease(n1 time.Duration, n2 time.Duration) time.Duration {
	if n2 < n1 {
		return n2
	}
	return n1
}
|
|
|
|
// CreateSession creates a new session environment.
|
|
func CreateSession(store kv.Storage) (Session, error) {
|
|
s, err := createSession(store)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
// Add auth here.
|
|
do, err := domap.Get(store)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
privChecker := &privileges.UserPrivileges{
|
|
Handle: do.PrivilegeHandle(),
|
|
}
|
|
privilege.BindPrivilegeChecker(s, privChecker)
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// BootstrapSession runs the first time when the TiDB server start.
|
|
func BootstrapSession(store kv.Storage) (*domain.Domain, error) {
|
|
ver := getStoreBootstrapVersion(store)
|
|
if ver == notBootstrapped {
|
|
runInBootstrapSession(store, bootstrap)
|
|
} else if ver < currentBootstrapVersion {
|
|
runInBootstrapSession(store, upgrade)
|
|
}
|
|
|
|
se, err := createSession(store)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
dom := sessionctx.GetDomain(se)
|
|
err = dom.LoadPrivilegeLoop(se)
|
|
|
|
return dom, errors.Trace(err)
|
|
}
|
|
|
|
// runInBootstrapSession create a special session for boostrap to run.
|
|
// If no bootstrap and storage is remote, we must use a little lease time to
|
|
// bootstrap quickly, after bootstrapped, we will reset the lease time.
|
|
// TODO: Using a bootstap tool for doing this may be better later.
|
|
func runInBootstrapSession(store kv.Storage, bootstrap func(Session)) {
|
|
saveLease := schemaLease
|
|
if !localstore.IsLocalStore(store) {
|
|
schemaLease = chooseMinLease(schemaLease, 100*time.Millisecond)
|
|
}
|
|
s, err := createSession(store)
|
|
if err != nil {
|
|
// Bootstrap fail will cause program exit.
|
|
log.Fatal(errors.ErrorStack(err))
|
|
}
|
|
schemaLease = saveLease
|
|
|
|
s.SetValue(context.Initing, true)
|
|
bootstrap(s)
|
|
finishBootstrap(store)
|
|
s.ClearValue(context.Initing)
|
|
|
|
domain := sessionctx.GetDomain(s)
|
|
domain.Close()
|
|
domap.Delete(store)
|
|
}
|
|
|
|
func createSession(store kv.Storage) (*session, error) {
|
|
s := &session{
|
|
values: make(map[fmt.Stringer]interface{}),
|
|
store: store,
|
|
parser: parser.New(),
|
|
sessionVars: variable.NewSessionVars(),
|
|
}
|
|
domain, err := domap.Get(store)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
sessionctx.BindDomain(s, domain)
|
|
// session implements variable.GlobalVarAccessor. Bind it to ctx.
|
|
s.sessionVars.GlobalVarsAccessor = s
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// Bootstrap versions read and written through the meta layer.
const (
	// notBootstrapped is the version of a store that was never bootstrapped.
	notBootstrapped = 0
	// currentBootstrapVersion is the version finishBootstrap persists.
	currentBootstrapVersion = 3
)
|
|
|
|
func getStoreBootstrapVersion(store kv.Storage) int64 {
|
|
// check in memory
|
|
_, ok := storeBootstrapped[store.UUID()]
|
|
if ok {
|
|
return currentBootstrapVersion
|
|
}
|
|
|
|
var ver int64
|
|
// check in kv store
|
|
err := kv.RunInNewTxn(store, false, func(txn kv.Transaction) error {
|
|
var err error
|
|
t := meta.NewMeta(txn)
|
|
ver, err = t.GetBootstrapVersion()
|
|
return errors.Trace(err)
|
|
})
|
|
|
|
if err != nil {
|
|
log.Fatalf("check bootstrapped err %v", err)
|
|
}
|
|
|
|
if ver > notBootstrapped {
|
|
// here mean memory is not ok, but other server has already finished it
|
|
storeBootstrapped[store.UUID()] = true
|
|
}
|
|
|
|
return ver
|
|
}
|
|
|
|
func finishBootstrap(store kv.Storage) {
|
|
storeBootstrapped[store.UUID()] = true
|
|
|
|
err := kv.RunInNewTxn(store, true, func(txn kv.Transaction) error {
|
|
t := meta.NewMeta(txn)
|
|
err := t.FinishBootstrap(currentBootstrapVersion)
|
|
return errors.Trace(err)
|
|
})
|
|
if err != nil {
|
|
log.Fatalf("finish bootstrap err %v", err)
|
|
}
|
|
}
|
|
|
|
const loadCommonGlobalVarsSQL = "select * from mysql.global_variables where variable_name in ('" +
|
|
variable.AutocommitVar + "', '" +
|
|
variable.SQLModeVar + "', '" +
|
|
variable.DistSQLJoinConcurrencyVar + "', '" +
|
|
variable.MaxAllowedPacket + "', '" +
|
|
variable.DistSQLScanConcurrencyVar + "')"
|
|
|
|
// LoadCommonGlobalVariableIfNeeded loads and applies commonly used global variables for the session.
|
|
func (s *session) loadCommonGlobalVariablesIfNeeded() error {
|
|
vars := s.sessionVars
|
|
if vars.CommonGlobalLoaded {
|
|
return nil
|
|
}
|
|
if s.Value(context.Initing) != nil {
|
|
// When running bootstrap or upgrade, we should not access global storage.
|
|
return nil
|
|
}
|
|
// Set the variable to true to prevent cyclic recursive call.
|
|
vars.CommonGlobalLoaded = true
|
|
rows, _, err := s.ExecRestrictedSQL(s, loadCommonGlobalVarsSQL)
|
|
if err != nil {
|
|
vars.CommonGlobalLoaded = false
|
|
log.Errorf("Failed to load common global variables.")
|
|
return errors.Trace(err)
|
|
}
|
|
for _, row := range rows {
|
|
varName := row.Data[0].GetString()
|
|
if _, ok := vars.Systems[varName]; !ok {
|
|
varsutil.SetSessionSystemVar(s.sessionVars, varName, row.Data[1])
|
|
}
|
|
}
|
|
vars.CommonGlobalLoaded = true
|
|
return nil
|
|
}
|
|
|
|
type txnWithErr struct {
|
|
txn kv.Transaction
|
|
err error
|
|
}
|
|
|
|
// prepareTxnCtx starts a goroutine to begin a transaction if needed, and creates a new transaction context.
|
|
// It is called before we execute a sql query.
|
|
func (s *session) prepareTxnCtx() {
|
|
if s.txn != nil && s.txn.Valid() {
|
|
return
|
|
}
|
|
if s.txnCh != nil {
|
|
return
|
|
}
|
|
txnCh := make(chan *txnWithErr, 1)
|
|
go func() {
|
|
txn, err := s.store.Begin()
|
|
txnCh <- &txnWithErr{txn: txn, err: err}
|
|
}()
|
|
goCtx, cancelFunc := goctx.WithCancel(goctx.Background())
|
|
s.txnCh, s.goCtx, s.cancelFunc = txnCh, goCtx, cancelFunc
|
|
is := sessionctx.GetDomain(s).InfoSchema()
|
|
s.sessionVars.TxnCtx = &variable.TransactionContext{
|
|
InfoSchema: is,
|
|
SchemaVersion: is.SchemaMetaVersion(),
|
|
}
|
|
if !s.sessionVars.IsAutocommit() {
|
|
s.sessionVars.SetStatusFlag(mysql.ServerStatusInTrans, true)
|
|
}
|
|
}
|
|
|
|
// ActivePendingTxn implements Session.ActivePendingTxn interface.
|
|
func (s *session) ActivePendingTxn() error {
|
|
if s.txn != nil && s.txn.Valid() {
|
|
return nil
|
|
}
|
|
if s.txnCh == nil {
|
|
return errors.New("transaction channel is not set")
|
|
}
|
|
txnWithErr := <-s.txnCh
|
|
s.txnCh = nil
|
|
if txnWithErr.err != nil {
|
|
return errors.Trace(txnWithErr.err)
|
|
}
|
|
s.txn = txnWithErr.txn
|
|
err := s.loadCommonGlobalVariablesIfNeeded()
|
|
if err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// InitTxnWithStartTS create a transaction with startTS.
|
|
func (s *session) InitTxnWithStartTS(startTS uint64) error {
|
|
if s.txn != nil && s.txn.Valid() {
|
|
return nil
|
|
}
|
|
if s.txnCh == nil {
|
|
return errors.New("transaction channel is not set")
|
|
}
|
|
// no need to get txn from txnCh since txn should init with startTs
|
|
s.txnCh = nil
|
|
var err error
|
|
s.txn, err = s.store.BeginWithStartTS(startTS)
|
|
if err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
err = s.loadCommonGlobalVariablesIfNeeded()
|
|
if err != nil {
|
|
return errors.Trace(err)
|
|
}
|
|
return nil
|
|
}
|