diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index 77eeb99d99..167fa4898f 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -466,3 +466,41 @@ func (s *testSuite) TestUseMultiplyBindings(c *C) { tk.MustQuery("select * from t where a >= 1 and b >= 4 and c = 0") c.Assert(tk.Se.GetSessionVars().StmtCtx.IndexNames[0], Equals, "t:idx_b") } + +func (s *testSuite) TestDropSingleBindings(c *C) { + tk := testkit.NewTestKit(c, s.store) + s.cleanBindingEnv(tk) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idx_b(b))") + + // Test drop session bindings. + tk.MustExec("create binding for select * from t using select * from t use index(idx_a)") + tk.MustExec("create binding for select * from t using select * from t use index(idx_b)") + rows := tk.MustQuery("show bindings").Rows() + c.Assert(len(rows), Equals, 2) + c.Assert(rows[0][1], Equals, "select * from t use index(idx_a)") + c.Assert(rows[1][1], Equals, "select * from t use index(idx_b)") + tk.MustExec("drop binding for select * from t using select * from t use index(idx_a)") + rows = tk.MustQuery("show bindings").Rows() + c.Assert(len(rows), Equals, 1) + c.Assert(rows[0][1], Equals, "select * from t use index(idx_b)") + tk.MustExec("drop binding for select * from t using select * from t use index(idx_b)") + rows = tk.MustQuery("show bindings").Rows() + c.Assert(len(rows), Equals, 0) + + // Test drop global bindings. + tk.MustExec("create global binding for select * from t using select * from t use index(idx_a)") + tk.MustExec("create global binding for select * from t using select * from t use index(idx_b)") + rows = tk.MustQuery("show global bindings").Rows() + c.Assert(len(rows), Equals, 2) + c.Assert(rows[0][1], Equals, "select * from t use index(idx_a)") + c.Assert(rows[1][1], Equals, "select * from t use index(idx_b)") + tk.MustExec("drop global binding for select * from t using select * from t use index(idx_a)") + rows = tk.MustQuery("show global bindings").Rows() + c.Assert(len(rows), Equals, 1) + c.Assert(rows[0][1], Equals, "select * from t use index(idx_b)") + tk.MustExec("drop global binding for select * from t using select * from t use index(idx_b)") + rows = tk.MustQuery("show global bindings").Rows() + c.Assert(len(rows), Equals, 0) +} diff --git a/bindinfo/cache.go b/bindinfo/cache.go index 1d08916a65..a6389a14fb 100644 --- a/bindinfo/cache.go +++ b/bindinfo/cache.go @@ -81,10 +81,10 @@ func (br *BindRecord) FindUsingBinding(hint string) *Binding { return nil } -func (br *BindRecord) prepareHintsForUsing(sctx sessionctx.Context, is infoschema.InfoSchema) error { +func (br *BindRecord) prepareHints(sctx sessionctx.Context, is infoschema.InfoSchema) error { p := parser.New() for i, bind := range br.Bindings { - if bind.Status != Using || bind.Hint != nil { + if bind.Hint != nil { continue } stmtNode, err := p.ParseOneStmt(bind.BindSQL, bind.Charset, bind.Collation) diff --git a/bindinfo/handle.go b/bindinfo/handle.go index db1753ac1d..4d9196d0d6 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -146,7 +146,7 @@ func (h *BindHandle) Update(fullLoad bool) (err error) { // AddBindRecord adds a BindRecord to the storage and BindRecord to the cache. func (h *BindHandle) AddBindRecord(sctx sessionctx.Context, is infoschema.InfoSchema, record *BindRecord) (err error) { - err = record.prepareHintsForUsing(sctx, is) + err = record.prepareHints(sctx, is) if err != nil { return err } @@ -179,10 +179,19 @@ func (h *BindHandle) AddBindRecord(sctx sessionctx.Context, is infoschema.InfoSc h.bindInfo.Unlock() }() - // remove all the unused sql binds. - _, err = exec.Execute(context.TODO(), h.deleteBindInfoSQL(record.OriginalSQL, record.Db)) - if err != nil { - return err + oldBindRecord := h.GetBindRecord(parser.DigestHash(record.OriginalSQL), record.OriginalSQL, record.Db) + if oldBindRecord != nil { + for _, newBinding := range record.Bindings { + binding := oldBindRecord.FindUsingBinding(newBinding.id) + if binding == nil { + continue + } + // Remove duplicates before insert. + _, err = exec.Execute(context.TODO(), h.deleteBindInfoSQL(record.OriginalSQL, record.Db, binding.BindSQL)) + if err != nil { + return err + } + } } txn, err1 := h.sctx.Context.Txn(true) @@ -208,7 +217,11 @@ func (h *BindHandle) AddBindRecord(sctx sessionctx.Context, is infoschema.InfoSc } // DropBindRecord drops a BindRecord to the storage and BindRecord int the cache. -func (h *BindHandle) DropBindRecord(record *BindRecord) (err error) { +func (h *BindHandle) DropBindRecord(sctx sessionctx.Context, is infoschema.InfoSchema, record *BindRecord) (err error) { + err = record.prepareHints(sctx, is) + if err != nil { + return err + } exec, _ := h.sctx.Context.(sqlexec.SQLExecutor) h.sctx.Lock() @@ -245,12 +258,21 @@ func (h *BindHandle) DropBindRecord(record *BindRecord) (err error) { Type: mysql.TypeDatetime, Fsp: 3, } + oldBindRecord := h.GetBindRecord(parser.DigestHash(record.OriginalSQL), record.OriginalSQL, record.Db) + bindingSQLs := make([]string, 0, len(record.Bindings)) for i := range record.Bindings { record.Bindings[i].Status = deleted record.Bindings[i].UpdateTime = updateTs + if oldBindRecord == nil { + continue + } + binding := oldBindRecord.FindUsingBinding(record.Bindings[i].id) + if binding != nil { + bindingSQLs = append(bindingSQLs, binding.BindSQL) + } } - _, err = exec.Execute(context.TODO(), h.logicalDeleteBindInfoSQL(record, updateTs)) + _, err = exec.Execute(context.TODO(), h.logicalDeleteBindInfoSQL(record.OriginalSQL, record.Db, updateTs, bindingSQLs)) return err } @@ -259,7 +281,7 @@ func (h *BindHandle) DropInvalidBindRecord() { invalidBindRecordMap := copyInvalidBindRecordMap(h.invalidBindRecordMap.Load().(map[string]*invalidBindRecordMap)) for key, invalidBindRecord := range invalidBindRecordMap { if invalidBindRecord.droppedTime.IsZero() { - err := h.DropBindRecord(invalidBindRecord.bindRecord) + err := h.DropBindRecord(nil, nil, invalidBindRecord.bindRecord) if err != nil { logutil.BgLogger().Error("DropInvalidBindRecord failed", zap.Error(err)) } @@ -339,7 +361,7 @@ func (h *BindHandle) newBindRecord(row chunk.Row) (string, *BindRecord, error) { return "", nil, err } h.sctx.GetSessionVars().CurrentDB = bindRecord.Db - err = bindRecord.prepareHintsForUsing(h.sctx.Context, h.sctx.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema)) + err = bindRecord.prepareHints(h.sctx.Context, h.sctx.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema)) return hash, bindRecord, err } @@ -349,7 +371,7 @@ func (h *BindHandle) appendBindRecord(hash string, meta *BindRecord) { newCache := h.bindInfo.Value.Load().(cache).copy() oldRecord := newCache.getBindRecord(hash, meta.OriginalSQL, meta.Db) newRecord := merge(oldRecord, meta) - newCache.setBindRecord(hash, meta) + newCache.setBindRecord(hash, newRecord) h.bindInfo.Value.Store(newCache) updateMetrics(metrics.ScopeGlobal, oldRecord, newRecord, false) } @@ -429,11 +451,12 @@ func (c cache) getBindRecord(hash, normdOrigSQL, db string) *BindRecord { return nil } -func (h *BindHandle) deleteBindInfoSQL(normdOrigSQL, db string) string { +func (h *BindHandle) deleteBindInfoSQL(normdOrigSQL, db, bindSQL string) string { return fmt.Sprintf( - `DELETE FROM mysql.bind_info WHERE original_sql=%s AND default_db=%s`, + `DELETE FROM mysql.bind_info WHERE original_sql=%s AND default_db=%s AND bind_sql = %s`, expression.Quote(normdOrigSQL), expression.Quote(db), + expression.Quote(bindSQL), ) } @@ -450,20 +473,19 @@ func (h *BindHandle) insertBindInfoSQL(orignalSQL string, db string, info Bindin ) } -func (h *BindHandle) logicalDeleteBindInfoSQL(record *BindRecord, updateTs types.Time) string { +func (h *BindHandle) logicalDeleteBindInfoSQL(originalSQL, db string, updateTs types.Time, bindingSQLs []string) string { sql := fmt.Sprintf(`UPDATE mysql.bind_info SET status=%s,update_time=%s WHERE original_sql=%s and default_db=%s`, expression.Quote(deleted), expression.Quote(updateTs.String()), - expression.Quote(record.OriginalSQL), - expression.Quote(record.Db)) - if len(record.Bindings) == 0 { + expression.Quote(originalSQL), + expression.Quote(db)) + if len(bindingSQLs) == 0 { return sql } - bindings := make([]string, 0, len(record.Bindings)) - for _, bind := range record.Bindings { - bindings = append(bindings, fmt.Sprintf(`%s`, expression.Quote(bind.BindSQL))) + for i, sql := range bindingSQLs { + bindingSQLs[i] = fmt.Sprintf(`%s`, expression.Quote(sql)) } - return sql + fmt.Sprintf(` and bind_sql in (%s)`, strings.Join(bindings, ",")) + return sql + fmt.Sprintf(` and bind_sql in (%s)`, strings.Join(bindingSQLs, ",")) } // GenHintsFromSQL is used to generate hints from SQL. diff --git a/bindinfo/session_handle.go b/bindinfo/session_handle.go index 321fa5ed8d..d7ccc08494 100644 --- a/bindinfo/session_handle.go +++ b/bindinfo/session_handle.go @@ -58,7 +58,7 @@ func (h *SessionHandle) AddBindRecord(sctx sessionctx.Context, is infoschema.Inf record.Bindings[i].UpdateTime = record.Bindings[i].CreateTime } - err := record.prepareHintsForUsing(sctx, is) + err := record.prepareHints(sctx, is) // update the BindMeta to the cache. if err == nil { h.appendBindRecord(parser.DigestHash(record.OriginalSQL), record) @@ -67,7 +67,11 @@ func (h *SessionHandle) AddBindRecord(sctx sessionctx.Context, is infoschema.Inf } // DropBindRecord drops a BindRecord in the cache. -func (h *SessionHandle) DropBindRecord(record *BindRecord) { +func (h *SessionHandle) DropBindRecord(sctx sessionctx.Context, is infoschema.InfoSchema, record *BindRecord) error { + err := record.prepareHints(sctx, is) + if err != nil { + return err + } oldRecord := h.GetBindRecord(record.OriginalSQL, record.Db) var newRecord *BindRecord if oldRecord != nil { @@ -77,6 +81,7 @@ func (h *SessionHandle) DropBindRecord(record *BindRecord) { } h.ch.setBindRecord(parser.DigestHash(record.OriginalSQL), newRecord) updateMetrics(metrics.ScopeSession, oldRecord, newRecord, false) + return nil } // GetBindRecord return the BindMeta of the (normdOrigSQL,db) if BindMeta exist. diff --git a/executor/bind.go b/executor/bind.go index 49c065da57..1b2ab0553c 100644 --- a/executor/bind.go +++ b/executor/bind.go @@ -55,12 +55,19 @@ func (e *SQLBindExec) dropSQLBind() error { OriginalSQL: e.normdOrigSQL, Db: e.ctx.GetSessionVars().CurrentDB, } + if e.bindSQL != "" { + bindInfo := bindinfo.Binding{ + BindSQL: e.bindSQL, + Charset: e.charset, + Collation: e.collation, + } + record.Bindings = append(record.Bindings, bindInfo) + } if !e.isGlobal { handle := e.ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) - handle.DropBindRecord(record) - return nil + return handle.DropBindRecord(e.ctx, GetInfoSchema(e.ctx), record) } - return domain.GetDomain(e.ctx).BindHandle().DropBindRecord(record) + return domain.GetDomain(e.ctx).BindHandle().DropBindRecord(e.ctx, GetInfoSchema(e.ctx), record) } func (e *SQLBindExec) createSQLBind() error { diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index bce9dbcc03..1db2033078 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -501,6 +501,9 @@ func (b *PlanBuilder) buildDropBindPlan(v *ast.DropBindingStmt) (Plan, error) { NormdOrigSQL: parser.Normalize(v.OriginSel.Text()), IsGlobal: v.GlobalScope, } + if v.HintedSel != nil { + p.BindSQL = v.HintedSel.Text() + } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SuperPriv, "", "", "", nil) return p, nil } diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 5abd950b66..b5608e7645 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -116,7 +116,11 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { case *ast.Join: p.checkNonUniqTableAlias(node) case *ast.CreateBindingStmt: - p.checkBindGrammar(node) + p.checkBindGrammar(node.OriginSel, node.HintedSel) + case *ast.DropBindingStmt: + if node.HintedSel != nil { + p.checkBindGrammar(node.OriginSel, node.HintedSel) + } case *ast.RecoverTableStmt: // The specified table in recover table statement maybe already been dropped. // So skip check table name here, otherwise, recover table [table_name] syntax will return @@ -128,9 +132,9 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { return in, p.err != nil } -func (p *preprocessor) checkBindGrammar(createBindingStmt *ast.CreateBindingStmt) { - originSQL := parser.Normalize(createBindingStmt.OriginSel.(*ast.SelectStmt).Text()) - hintedSQL := parser.Normalize(createBindingStmt.HintedSel.(*ast.SelectStmt).Text()) +func (p *preprocessor) checkBindGrammar(originSel, hintedSel ast.StmtNode) { + originSQL := parser.Normalize(originSel.(*ast.SelectStmt).Text()) + hintedSQL := parser.Normalize(hintedSel.(*ast.SelectStmt).Text()) if originSQL != hintedSQL { p.err = errors.Errorf("hinted sql and origin sql don't match when hinted sql erase the hint info, after erase hint info, originSQL:%s, hintedSQL:%s", originSQL, hintedSQL) diff --git a/planner/optimize.go b/planner/optimize.go index 7c48176c76..d8c9296b9b 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/logutil" ) // Optimize does optimization and creates a Plan. @@ -75,7 +76,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in plan, _, cost, err := optimize(ctx, sctx, node, is) if err != nil { binding.Status = bindinfo.Invalid - handleInvalidBindRecord(sctx, scope, bindinfo.BindRecord{ + handleInvalidBindRecord(ctx, sctx, scope, bindinfo.BindRecord{ OriginalSQL: bindRecord.OriginalSQL, Db: bindRecord.Db, Bindings: []bindinfo.Binding{binding}, @@ -189,9 +190,14 @@ func getBindRecord(ctx sessionctx.Context, stmt ast.StmtNode) (*bindinfo.BindRec return bindRecord, metrics.ScopeGlobal } -func handleInvalidBindRecord(sctx sessionctx.Context, level string, bindRecord bindinfo.BindRecord) { +func handleInvalidBindRecord(ctx context.Context, sctx sessionctx.Context, level string, bindRecord bindinfo.BindRecord) { sessionHandle := sctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) - sessionHandle.DropBindRecord(&bindRecord) + // The first two parameters are only used to generate hints, but since we already have the hints, + // we do not need to pass real values and the error won't happen too. + err := sessionHandle.DropBindRecord(nil, nil, &bindRecord) + if err != nil { + logutil.Logger(ctx).Info("drop session bindings failed") + } if level == metrics.ScopeSession { return }