tidb/pkg/ttl/sqlbuilder/sql.go

// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlbuilder

import (
	"encoding/hex"
	"fmt"
	"io"
	"strconv"
	"strings"
	"time"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/pkg/meta/model"
	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/parser/format"
	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/ttl/cache"
	"github.com/pingcap/tidb/pkg/types"
	"github.com/pingcap/tidb/pkg/util/sqlescape"
)

func writeHex(in io.Writer, d types.Datum) error {
	_, err := fmt.Fprintf(in, "x'%s'", hex.EncodeToString(d.GetBytes()))
	return err
}

func writeDatum(restoreCtx *format.RestoreCtx, d types.Datum, ft *types.FieldType) error {
	switch ft.GetType() {
	case mysql.TypeBit, mysql.TypeBlob, mysql.TypeLongBlob, mysql.TypeTinyBlob:
		return writeHex(restoreCtx.In, d)
	case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeEnum, mysql.TypeSet:
		if mysql.HasBinaryFlag(ft.GetFlag()) {
			return writeHex(restoreCtx.In, d)
		}
		_, err := fmt.Fprintf(restoreCtx.In, "'%s'", sqlescape.EscapeString(d.GetString()))
		return err
	}
	expr := ast.NewValueExpr(d.GetValue(), ft.GetCharset(), ft.GetCollate())
	return expr.Restore(restoreCtx)
}

// FormatSQLDatum formats the datum to a value string that can be embedded in SQL
func FormatSQLDatum(d types.Datum, ft *types.FieldType) (string, error) {
	var sb strings.Builder
	ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
	if err := writeDatum(ctx, d, ft); err != nil {
		return "", err
	}
	return sb.String(), nil
}

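// A rough sketch of how the rules above come out (the calls below are illustrative and
// not part of the original file): string-like datums are rendered as escaped quoted
// strings, while binary datums are rendered as hex literals so that arbitrary bytes
// survive being embedded in SQL text.
//
//	s, _ := FormatSQLDatum(types.NewStringDatum("hello"), types.NewFieldType(mysql.TypeVarchar))
//	// s == "'hello'"
//	b, _ := FormatSQLDatum(types.NewBytesDatum([]byte{0xde, 0xad}), types.NewFieldType(mysql.TypeBlob))
//	// b == "x'dead'"
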
type sqlBuilderState int

const (
	writeBegin sqlBuilderState = iota
	writeSelOrDel
	writeWhere
	writeOrderBy
	writeLimit
	writeDone
)

// SQLBuilder is used to build SQLs for TTL
type SQLBuilder struct {
	tbl                *cache.PhysicalTable
	sb                 strings.Builder
	restoreCtx         *format.RestoreCtx
	state              sqlBuilderState
	isReadOnly         bool
	hasWriteExpireCond bool
}

// NewSQLBuilder creates a new SQLBuilder for the given physical table
func NewSQLBuilder(tbl *cache.PhysicalTable) *SQLBuilder {
	b := &SQLBuilder{tbl: tbl, state: writeBegin}
	b.restoreCtx = format.NewRestoreCtx(format.DefaultRestoreFlags, &b.sb)
	return b
}

// Build builds the final SQL
func (b *SQLBuilder) Build() (string, error) {
	if b.state == writeBegin {
		return "", errors.Errorf("invalid state: %v", b.state)
	}

	if !b.isReadOnly && !b.hasWriteExpireCond {
		// Check whether the `time column < expire_time` condition has been written, to make sure this SQL is safe to execute.
		return "", errors.New("expire condition not written")
	}

	if b.state != writeDone {
		b.state = writeDone
	}

	return b.sb.String(), nil
}

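// A minimal usage sketch (tbl and expire are assumed to exist elsewhere; they are not
// defined in this file):
//
//	b := NewSQLBuilder(tbl)
//	_ = b.WriteSelect()                // SELECT LOW_PRIORITY SQL_NO_CACHE <key columns> FROM <table>
//	_ = b.WriteExpireCondition(expire) // WHERE <time column> < FROM_UNIXTIME(<expire.Unix()>)
//	_ = b.WriteLimit(100)              // LIMIT 100
//	sql, err := b.Build()
//
// Build refuses to return a statement that can write data (one started with WriteDelete)
// unless WriteExpireCondition has been called; that is the safety check above.
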
// WriteSelect writes a select statement to select key columns without any condition
func (b *SQLBuilder) WriteSelect() error {
	if b.state != writeBegin {
		return errors.Errorf("invalid state: %v", b.state)
	}
	b.restoreCtx.WritePlain("SELECT LOW_PRIORITY SQL_NO_CACHE ")
	b.writeColNames(b.tbl.KeyColumns, false)
	b.restoreCtx.WritePlain(" FROM ")
	if err := b.writeTblName(); err != nil {
		return err
	}
	if par := b.tbl.PartitionDef; par != nil {
		b.restoreCtx.WritePlain(" PARTITION(")
		b.restoreCtx.WriteName(par.Name.O)
		b.restoreCtx.WritePlain(")")
	}
	b.state = writeSelOrDel
	b.isReadOnly = true
	return nil
}

// WriteDelete writes a delete statement without any condition
func (b *SQLBuilder) WriteDelete() error {
	if b.state != writeBegin {
		return errors.Errorf("invalid state: %v", b.state)
	}
	b.restoreCtx.WritePlain("DELETE LOW_PRIORITY FROM ")
	if err := b.writeTblName(); err != nil {
		return err
	}
	if par := b.tbl.PartitionDef; par != nil {
		b.restoreCtx.WritePlain(" PARTITION(")
		b.restoreCtx.WriteName(par.Name.O)
		b.restoreCtx.WritePlain(")")
	}
	b.state = writeSelOrDel
	return nil
}

// WriteCommonCondition writes a comparison condition on the given columns, joined to any previous condition with AND
func (b *SQLBuilder) WriteCommonCondition(cols []*model.ColumnInfo, op string, dp []types.Datum) error {
	switch b.state {
	case writeSelOrDel:
		b.restoreCtx.WritePlain(" WHERE ")
		b.state = writeWhere
	case writeWhere:
		b.restoreCtx.WritePlain(" AND ")
	default:
		return errors.Errorf("invalid state: %v", b.state)
	}

	b.writeColNames(cols, len(cols) > 1)
	b.restoreCtx.WritePlain(" ")
	b.restoreCtx.WritePlain(op)
	b.restoreCtx.WritePlain(" ")
	return b.writeDataPoint(cols, dp)
}

// WriteExpireCondition writes a condition with the time column
func (b *SQLBuilder) WriteExpireCondition(expire time.Time) error {
	switch b.state {
	case writeSelOrDel:
		b.restoreCtx.WritePlain(" WHERE ")
		b.state = writeWhere
	case writeWhere:
		b.restoreCtx.WritePlain(" AND ")
	default:
		return errors.Errorf("invalid state: %v", b.state)
	}

	b.writeColNames([]*model.ColumnInfo{b.tbl.TimeColumn}, false)
	b.restoreCtx.WritePlain(" < ")
	b.restoreCtx.WritePlain("FROM_UNIXTIME(")
	b.restoreCtx.WritePlain(strconv.FormatInt(expire.Unix(), 10))
	b.restoreCtx.WritePlain(")")
	b.hasWriteExpireCond = true
	return nil
}

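// For illustration (the column name and timestamp are invented): with a TTL time column
// named `create_time` and an expire time of 2023-01-01 00:00:00 UTC, the fragment written
// above is
//
//	WHERE `create_time` < FROM_UNIXTIME(1672531200)
//
// The cutoff is passed as a Unix timestamp through FROM_UNIXTIME, presumably to avoid
// formatting a time-zone-dependent datetime literal into the SQL text.
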
// WriteInCondition writes an IN condition
func (b *SQLBuilder) WriteInCondition(cols []*model.ColumnInfo, dps ...[]types.Datum) error {
	switch b.state {
	case writeSelOrDel:
		b.restoreCtx.WritePlain(" WHERE ")
		b.state = writeWhere
	case writeWhere:
		b.restoreCtx.WritePlain(" AND ")
	default:
		return errors.Errorf("invalid state: %v", b.state)
	}

	b.writeColNames(cols, len(cols) > 1)
	b.restoreCtx.WritePlain(" IN ")
	b.restoreCtx.WritePlain("(")
	first := true
	for _, v := range dps {
		if first {
			first = false
		} else {
			b.restoreCtx.WritePlain(", ")
		}
		if err := b.writeDataPoint(cols, v); err != nil {
			return err
		}
	}
	b.restoreCtx.WritePlain(")")
	return nil
}

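// Illustrative fragments (column names and values invented for the sketch): for a
// single-column key `id` with data points 1 and 2 this writes
//
//	WHERE `id` IN (1, 2)
//
// and for a composite key (`a`, `b`) each data point is wrapped in parentheses:
//
//	WHERE (`a`, `b`) IN (('x', 1), ('y', 2))
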
// WriteOrderBy writes the ORDER BY clause
func (b *SQLBuilder) WriteOrderBy(cols []*model.ColumnInfo, desc bool) error {
	if b.state != writeSelOrDel && b.state != writeWhere {
		return errors.Errorf("invalid state: %v", b.state)
	}
	b.state = writeOrderBy
	b.restoreCtx.WritePlain(" ORDER BY ")
	b.writeColNames(cols, false)
	if desc {
		b.restoreCtx.WritePlain(" DESC")
	} else {
		b.restoreCtx.WritePlain(" ASC")
	}
	return nil
}

// WriteLimit writes the LIMIT clause
func (b *SQLBuilder) WriteLimit(n int) error {
	if b.state != writeSelOrDel && b.state != writeWhere && b.state != writeOrderBy {
		return errors.Errorf("invalid state: %v", b.state)
	}
	b.state = writeLimit
	b.restoreCtx.WritePlain(" LIMIT ")
	b.restoreCtx.WritePlain(strconv.Itoa(n))
	return nil
}

func (b *SQLBuilder) writeTblName() error {
	tn := ast.TableName{Schema: b.tbl.Schema, Name: b.tbl.Name}
	return tn.Restore(b.restoreCtx)
}

func (b *SQLBuilder) writeColName(col *model.ColumnInfo) {
	b.restoreCtx.WriteName(col.Name.O)
}

func (b *SQLBuilder) writeColNames(cols []*model.ColumnInfo, writeBrackets bool) {
	if writeBrackets {
		b.restoreCtx.WritePlain("(")
	}

	first := true
	for _, col := range cols {
		if first {
			first = false
		} else {
			b.restoreCtx.WritePlain(", ")
		}
		b.writeColName(col)
	}

	if writeBrackets {
		b.restoreCtx.WritePlain(")")
	}
}

func (b *SQLBuilder) writeDataPoint(cols []*model.ColumnInfo, dp []types.Datum) error {
	writeBrackets := len(cols) > 1
	if len(cols) != len(dp) {
		return errors.Errorf("column count does not match: %d != %d", len(cols), len(dp))
	}

	if writeBrackets {
		b.restoreCtx.WritePlain("(")
	}

	first := true
	for i, d := range dp {
		if first {
			first = false
		} else {
			b.restoreCtx.WritePlain(", ")
		}
		if err := writeDatum(b.restoreCtx, d, &cols[i].FieldType); err != nil {
			return err
		}
	}

	if writeBrackets {
		b.restoreCtx.WritePlain(")")
	}

	return nil
}

// ScanQueryGenerator generates the SQLs for a scan task
type ScanQueryGenerator struct {
	tbl           *cache.PhysicalTable
	expire        time.Time
	keyRangeStart []types.Datum
	keyRangeEnd   []types.Datum
	stack         [][]types.Datum
	limit         int
	firstBuild    bool
	exhausted     bool
}

// NewScanQueryGenerator creates a new ScanQueryGenerator
func NewScanQueryGenerator(tbl *cache.PhysicalTable, expire time.Time,
	rangeStart, rangeEnd []types.Datum) (*ScanQueryGenerator, error) {
	if err := tbl.ValidateKeyPrefix(rangeStart); err != nil {
		return nil, err
	}

	if err := tbl.ValidateKeyPrefix(rangeEnd); err != nil {
		return nil, err
	}

	return &ScanQueryGenerator{
		tbl:           tbl,
		expire:        expire,
		keyRangeStart: rangeStart,
		keyRangeEnd:   rangeEnd,
		firstBuild:    true,
	}, nil
}

// NextSQL creates the next SQL of the scan task. If the previous query filled its limit,
// the scan continues after the last returned row; otherwise the current key prefix is
// exhausted and the generator moves on to the next (shorter) prefix.
func (g *ScanQueryGenerator) NextSQL(continueFromResult [][]types.Datum, nextLimit int) (string, error) {
	if g.exhausted {
		return "", errors.New("generator is exhausted")
	}

	if nextLimit <= 0 {
		return "", errors.Errorf("invalid limit '%d'", nextLimit)
	}

	defer func() {
		g.firstBuild = false
	}()

	if g.stack == nil {
		g.stack = make([][]types.Datum, 0, len(g.tbl.KeyColumns))
	}

	if len(continueFromResult) >= g.limit {
		var continueFromKey []types.Datum
		if cnt := len(continueFromResult); cnt > 0 {
			continueFromKey = continueFromResult[cnt-1]
		}
		if err := g.setStack(continueFromKey); err != nil {
			return "", err
		}
	} else {
		if l := len(g.stack); l > 0 {
			g.stack = g.stack[:l-1]
		}
		if len(g.stack) == 0 {
			g.exhausted = true
		}
	}

	g.limit = nextLimit
	return g.buildSQL()
}

// IsExhausted returns whether the generator is exhausted
func (g *ScanQueryGenerator) IsExhausted() bool {
	return g.exhausted
}

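// A sketch of how a caller might drive the generator (runQuery is a hypothetical helper
// that executes the SQL and returns the key columns of every returned row; it is not
// part of this file):
//
//	g, err := NewScanQueryGenerator(tbl, expire, rangeStart, rangeEnd)
//	if err != nil {
//		return err
//	}
//	var lastResult [][]types.Datum
//	for !g.IsExhausted() {
//		sql, err := g.NextSQL(lastResult, 500)
//		if err != nil {
//			return err
//		}
//		if sql == "" { // the generator became exhausted while building
//			break
//		}
//		lastResult = runQuery(sql)
//	}
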
func (g *ScanQueryGenerator) setStack(key []types.Datum) error {
	if key == nil {
		key = g.keyRangeStart
	}

	if key == nil {
		g.stack = g.stack[:0]
		return nil
	}

	if err := g.tbl.ValidateKeyPrefix(key); err != nil {
		return err
	}

	g.stack = g.stack[:len(key)]
	for i := range key {
		g.stack[i] = key[0 : i+1]
	}
	return nil
}

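// For illustration: with a three-column key and a continuation key (k1, k2, k3), setStack
// leaves the stack holding every prefix of that key:
//
//	[[k1], [k1, k2], [k1, k2, k3]]
//
// buildSQL always scans from the deepest prefix; when a prefix yields fewer rows than the
// limit, NextSQL pops it and continues with the shorter prefix above it.
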
func (g *ScanQueryGenerator) buildSQL() (string, error) {
	if g.limit <= 0 {
		return "", errors.Errorf("invalid limit '%d'", g.limit)
	}

	if g.exhausted {
		return "", nil
	}

	b := NewSQLBuilder(g.tbl)
	if err := b.WriteSelect(); err != nil {
		return "", err
	}

	if len(g.stack) > 0 {
		for i, d := range g.stack[len(g.stack)-1] {
			col := []*model.ColumnInfo{g.tbl.KeyColumns[i]}
			val := []types.Datum{d}
			var err error
			if i < len(g.stack)-1 {
				err = b.WriteCommonCondition(col, "=", val)
			} else if g.firstBuild {
				// When `g.firstBuild == true`, we are querying the rows after the range start. Because the
				// range is defined as [start, end), we use ">=" so that the rows matching the start key are
				// included.
				err = b.WriteCommonCondition(col, ">=", val)
			} else {
				// Otherwise we are continuing from a previous result, so we use ">" to exclude the row the
				// previous query ended on.
				err = b.WriteCommonCondition(col, ">", val)
			}
			if err != nil {
				return "", err
			}
		}
	}

	if len(g.keyRangeEnd) > 0 {
		if err := b.WriteCommonCondition(g.tbl.KeyColumns[0:len(g.keyRangeEnd)], "<", g.keyRangeEnd); err != nil {
			return "", err
		}
	}

	if err := b.WriteExpireCondition(g.expire); err != nil {
		return "", err
	}

	if err := b.WriteOrderBy(g.tbl.KeyColumns, false); err != nil {
		return "", err
	}

	if err := b.WriteLimit(g.limit); err != nil {
		return "", err
	}

	return b.Build()
}

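// An illustrative query from buildSQL (table, column, and value names invented for this
// sketch; the real output is a single line, wrapped here for readability): for a
// two-column key, a continuation key of ('a', 5), a range end of ('z'), and a limit of
// 500, the generated SQL has the shape
//
//	SELECT LOW_PRIORITY SQL_NO_CACHE `k1`, `k2` FROM `test`.`t`
//	  WHERE `k1` = 'a' AND `k2` > 5 AND `k1` < 'z'
//	  AND `create_time` < FROM_UNIXTIME(1672531200)
//	  ORDER BY `k1`, `k2` ASC LIMIT 500
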
// BuildDeleteSQL builds a delete SQL
func BuildDeleteSQL(tbl *cache.PhysicalTable, rows [][]types.Datum, expire time.Time) (string, error) {
	if len(rows) == 0 {
		return "", errors.New("Cannot build delete SQL with empty rows")
	}

	b := NewSQLBuilder(tbl)
	if err := b.WriteDelete(); err != nil {
		return "", err
	}

	if err := b.WriteInCondition(tbl.KeyColumns, rows...); err != nil {
		return "", err
	}

	if err := b.WriteExpireCondition(expire); err != nil {
		return "", err
	}

	if err := b.WriteLimit(len(rows)); err != nil {
		return "", err
	}

	return b.Build()
}

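// An illustrative statement from BuildDeleteSQL (names and values invented for this
// sketch; the real output is a single line), deleting two expired rows of a
// single-column key:
//
//	DELETE LOW_PRIORITY FROM `test`.`t` WHERE `id` IN (1, 2)
//	  AND `create_time` < FROM_UNIXTIME(1672531200) LIMIT 2
//
// The expire condition is repeated in the DELETE, presumably so that rows whose time
// column was updated after the scan are not removed; Build rejects the statement if it
// is missing.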