// tidb/lightning/pkg/server/lightning.go

// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
"cmp"
"compress/gzip"
"context"
"crypto/ecdsa"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/pprof"
"os"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/go-sql-driver/mysql"
"github.com/google/uuid"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/import_sstpb"
"github.com/pingcap/kvproto/pkg/metapb"
"github.com/pingcap/tidb/br/pkg/restore/split"
"github.com/pingcap/tidb/br/pkg/version/build"
"github.com/pingcap/tidb/lightning/pkg/importer"
"github.com/pingcap/tidb/lightning/pkg/importinto"
"github.com/pingcap/tidb/lightning/pkg/web"
_ "github.com/pingcap/tidb/pkg/expression" // get rid of `import cycle`: just init expression.RewriteAstExpr,and called at package `backend.kv`.
"github.com/pingcap/tidb/pkg/lightning/backend/local"
"github.com/pingcap/tidb/pkg/lightning/checkpoints"
"github.com/pingcap/tidb/pkg/lightning/common"
"github.com/pingcap/tidb/pkg/lightning/config"
"github.com/pingcap/tidb/pkg/lightning/log"
"github.com/pingcap/tidb/pkg/lightning/metric"
"github.com/pingcap/tidb/pkg/lightning/mydump"
"github.com/pingcap/tidb/pkg/lightning/tikv"
"github.com/pingcap/tidb/pkg/objstore"
"github.com/pingcap/tidb/pkg/objstore/storeapi"
_ "github.com/pingcap/tidb/pkg/planner/core" // init expression.EvalSimpleAst related function
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/promutil"
"github.com/pingcap/tidb/pkg/util/redact"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/shurcooL/httpgzip"
pdhttp "github.com/tikv/pd/client/http"
"go.uber.org/atomic"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// Lightning is the main struct of the lightning package.
type Lightning struct {
globalCfg *config.GlobalConfig
globalTLS *common.TLS
// taskCfgs is the list of task configurations enqueued in the server mode
taskCfgs *config.List
ctx context.Context
shutdown context.CancelFunc // for whole lightning context
server http.Server
serverAddr net.Addr
serverLock sync.Mutex
status importer.LightningStatus
promFactory promutil.Factory
promRegistry promutil.Registry
metrics *metric.Metrics
cancelLock sync.Mutex
curTask *config.Config
cancel context.CancelFunc // for the per-task context, which may differ from the lightning context
importer LightningImporter
taskCanceled bool
}
func initEnv(cfg *config.GlobalConfig) error {
if cfg.App.Config.File == "" {
return nil
}
return log.InitLogger(&cfg.App.Config, cfg.TiDB.LogLevel)
}
// New creates a new Lightning instance.
func New(globalCfg *config.GlobalConfig) *Lightning {
if err := initEnv(globalCfg); err != nil {
fmt.Println("Failed to initialize environment:", err)
os.Exit(1)
}
tls, err := common.NewTLS(
globalCfg.Security.CAPath,
globalCfg.Security.CertPath,
globalCfg.Security.KeyPath,
globalCfg.App.StatusAddr,
globalCfg.Security.CABytes,
globalCfg.Security.CertBytes,
globalCfg.Security.KeyBytes,
)
if err != nil {
log.L().Fatal("failed to load TLS certificates", zap.Error(err))
}
redact.InitRedact(globalCfg.Security.RedactInfoLog)
promFactory := promutil.NewDefaultFactory()
promRegistry := promutil.NewDefaultRegistry()
ctx, shutdown := context.WithCancel(context.Background())
return &Lightning{
globalCfg: globalCfg,
globalTLS: tls,
ctx: ctx,
shutdown: shutdown,
promFactory: promFactory,
promRegistry: promRegistry,
}
}
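
// A hedged usage sketch of the lifecycle around New (error handling elided;
// config.NewGlobalConfig and config.NewConfig are assumed to come from the
// lightning config package, as used elsewhere in this file):
//
//	globalCfg := config.NewGlobalConfig()
//	lightning := New(globalCfg)
//	_ = lightning.GoServe() // optional status server
//	taskCfg := config.NewConfig()
//	_ = taskCfg.LoadFromGlobal(globalCfg)
//	_ = lightning.RunOnceWithOptions(context.Background(), taskCfg)
//	lightning.Stop()
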
// GoServe starts the HTTP server in a goroutine. The server is closed when Stop is called.
func (l *Lightning) GoServe() error {
handleSigUsr1(func() {
l.serverLock.Lock()
statusAddr := l.globalCfg.App.StatusAddr
shouldStartServer := len(statusAddr) == 0
if shouldStartServer {
l.globalCfg.App.StatusAddr = ":"
}
l.serverLock.Unlock()
if shouldStartServer {
// open a random port and start the server if SIGUSR1 is received.
if err := l.goServe(":", os.Stderr); err != nil {
log.L().Warn("failed to start HTTP server", log.ShortError(err))
}
} else {
// just print the server address if it is already started.
log.L().Info("already started HTTP server", zap.Stringer("address", l.serverAddr))
}
})
l.serverLock.Lock()
statusAddr := l.globalCfg.App.StatusAddr
l.serverLock.Unlock()
if len(statusAddr) == 0 {
return nil
}
return l.goServe(statusAddr, io.Discard)
}
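
// Operational note (illustrative shell sketch): even when status-addr is left
// empty, sending SIGUSR1 makes the handler above start the status server on a
// random port.
//
//	kill -USR1 $(pidof tidb-lightning)
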
// TODO: maybe handle http requests using gin
type loggingResponseWriter struct {
http.ResponseWriter
statusCode int
body string
}
func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter {
return &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
}
// WriteHeader implements http.ResponseWriter.
func (lrw *loggingResponseWriter) WriteHeader(code int) {
lrw.statusCode = code
lrw.ResponseWriter.WriteHeader(code)
}
// Write implements http.ResponseWriter.
func (lrw *loggingResponseWriter) Write(d []byte) (int, error) {
// keep first part of the response for logging, max 1K
if lrw.body == "" && len(d) > 0 {
length := min(len(d), 1024)
lrw.body = string(d[:length])
}
return lrw.ResponseWriter.Write(d)
}
func httpHandleWrapper(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
logger := log.L().With(zap.String("method", r.Method), zap.Stringer("url", r.URL)).
Begin(zapcore.InfoLevel, "process http request")
newWriter := newLoggingResponseWriter(w)
h.ServeHTTP(newWriter, r)
bodyField := zap.Skip()
if newWriter.Header().Get("Content-Encoding") != "gzip" {
bodyField = zap.String("body", newWriter.body)
}
logger.End(zapcore.InfoLevel, nil, zap.Int("status", newWriter.statusCode), bodyField)
}
}
func (l *Lightning) goServe(statusAddr string, realAddrWriter io.Writer) error {
mux := http.NewServeMux()
mux.Handle("/", http.RedirectHandler("/web/", http.StatusFound))
registry := l.promRegistry
registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
registry.MustRegister(collectors.NewGoCollector())
if gatherer, ok := registry.(prometheus.Gatherer); ok {
handler := promhttp.InstrumentMetricHandler(
registry, promhttp.HandlerFor(gatherer, promhttp.HandlerOpts{}),
)
mux.Handle("/metrics", handler)
}
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
// Enable failpoint http API for testing.
failpoint.Inject("EnableTestAPI", func() {
mux.HandleFunc("/fail/", func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail")
new(failpoint.HttpHandler).ServeHTTP(w, r)
})
})
handleTasks := http.StripPrefix("/tasks", http.HandlerFunc(l.handleTask))
mux.Handle("/tasks", httpHandleWrapper(handleTasks.ServeHTTP))
mux.Handle("/tasks/", httpHandleWrapper(handleTasks.ServeHTTP))
mux.HandleFunc("/progress/task", httpHandleWrapper(handleProgressTask))
mux.HandleFunc("/progress/table", httpHandleWrapper(handleProgressTable))
mux.HandleFunc("/pause", httpHandleWrapper(l.handlePause))
mux.HandleFunc("/resume", httpHandleWrapper(l.handleResume))
mux.HandleFunc("/loglevel", httpHandleWrapper(handleLogLevel))
mux.Handle("/web/", http.StripPrefix("/web", httpgzip.FileServer(web.Res, httpgzip.FileServerOptions{
IndexHTML: true,
ServeError: func(w http.ResponseWriter, req *http.Request, err error) {
if os.IsNotExist(err) && !strings.Contains(req.URL.Path, ".") {
http.Redirect(w, req, "/web/", http.StatusFound)
} else {
httpgzip.NonSpecific(w, req, err)
}
},
})))
listener, err := net.Listen("tcp", statusAddr)
if err != nil {
return err
}
l.serverAddr = listener.Addr()
log.L().Info("starting HTTP server", zap.Stringer("address", l.serverAddr))
fmt.Fprintln(realAddrWriter, "started HTTP server on", l.serverAddr)
l.server.Handler = mux
listener = l.globalTLS.WrapListener(listener)
go func() {
err := l.server.Serve(listener)
log.L().Info("stopped HTTP server", log.ShortError(err))
}()
return nil
}
// RunServer is used by the lightning binary to start an HTTP server to receive import tasks.
func (l *Lightning) RunServer() error {
l.serverLock.Lock()
l.taskCfgs = config.NewConfigList()
l.serverLock.Unlock()
log.L().Info(
"Lightning server is running, post to /tasks to start an import task",
zap.Stringer("address", l.serverAddr),
)
for {
task, err := l.taskCfgs.Pop(l.ctx)
if err != nil {
return err
}
o := &options{
promFactory: l.promFactory,
promRegistry: l.promRegistry,
logger: log.L(),
}
err = l.run(context.Background(), task, o)
if err != nil && !common.IsContextCanceledError(err) {
importer.DeliverPauser.Pause() // force pause the progress on error
log.L().Error("tidb lightning encountered error", zap.Error(err))
}
}
}
// RunOnceWithOptions is used by the lightning binary and by hosts that embed lightning as a library.
//   - for the lightning binary, taskCtx can be context.Background(), meaning taskCtx is not canceled directly by its
//     cancel function but only by Lightning.Stop or an HTTP DELETE using l.cancel. No Options need to be set.
//   - for lightning as a library, taskCtx can be a meaningful context that gets canceled outside, and these Options may
//     be used:
//   - WithGlue: set a caller-implemented glue. Otherwise, lightning will use a default glue later.
//   - WithDumpFileStorage: the caller has opened an external storage for lightning. Otherwise, lightning will open a
//     storage by config.
//   - WithCheckpointStorage: the caller has opened an external storage for lightning and wants to save checkpoints
//     in it. Otherwise, lightning saves checkpoints according to the Checkpoint.DSN in the config.
func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config.Config, opts ...Option) error {
o := &options{
promFactory: l.promFactory,
promRegistry: l.promRegistry,
logger: log.L(),
}
for _, opt := range opts {
opt(o)
}
failpoint.Inject("setExtStorage", func(val failpoint.Value) {
path := val.(string)
b, err := objstore.ParseBackend(path, nil)
if err != nil {
panic(err)
}
s, err := objstore.New(context.Background(), b, &storeapi.Options{})
if err != nil {
panic(err)
}
o.dumpFileStorage = s
o.checkpointStorage = s
})
failpoint.Inject("setCheckpointName", func(val failpoint.Value) {
file := val.(string)
o.checkpointName = file
})
if o.dumpFileStorage != nil {
// SourceDir is not used in this case; set a placeholder value so Adjust passes
taskCfg.Mydumper.SourceDir = "noop://"
}
if err := taskCfg.Adjust(taskCtx); err != nil {
return err
}
r, err := rand.Int(rand.Reader, big.NewInt(1<<63-1))
if err != nil {
return err
}
taskCfg.TaskID = r.Int64()
failpoint.Inject("SetTaskID", func(val failpoint.Value) {
taskCfg.TaskID = int64(val.(int))
})
failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) {
o.logger.Info("set io total bytes")
taskCfg.TiDB.IOTotalBytes = atomic.NewUint64(0)
taskCfg.TiDB.UUID = uuid.New().String()
go func() {
for {
time.Sleep(time.Millisecond * 10)
log.L().Info("IOTotalBytes", zap.Uint64("IOTotalBytes", taskCfg.TiDB.IOTotalBytes.Load()))
}
}()
})
if taskCfg.TiDB.IOTotalBytes != nil {
o.logger.Info("found IO total bytes counter")
mysql.RegisterDialContext(taskCfg.TiDB.UUID, func(ctx context.Context, addr string) (net.Conn, error) {
o.logger.Debug("connection with IO bytes counter")
d := &net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
tcpConn := conn.(*net.TCPConn)
// try https://github.com/go-sql-driver/mysql/blob/bcc459a906419e2890a50fc2c99ea6dd927a88f2/connector.go#L56-L64
err = tcpConn.SetKeepAlive(true)
if err != nil {
o.logger.Warn("set TCP keep alive failed", zap.Error(err))
}
return util.NewTCPConnWithIOCounter(tcpConn, taskCfg.TiDB.IOTotalBytes), nil
})
}
return l.run(taskCtx, taskCfg, o)
}
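
// A hedged sketch of calling RunOnceWithOptions when embedding lightning as a
// library (the Option constructors such as WithDumpFileStorage live elsewhere
// in this package; their exact signatures are not shown here):
//
//	taskCfg := config.NewConfig()
//	if err := taskCfg.LoadFromGlobal(globalCfg); err != nil {
//		return err
//	}
//	err := l.RunOnceWithOptions(ctx, taskCfg /*, e.g. WithDumpFileStorage(...) */)
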
var (
taskRunNotifyKey = "taskRunNotifyKey"
taskCfgRecorderKey = "taskCfgRecorderKey"
)
func getKeyspaceName(db *sql.DB) (string, error) {
if db == nil {
return "", nil
}
rows, err := db.Query("show config where Type = 'tidb' and name = 'keyspace-name'")
if err != nil {
return "", err
}
//nolint: errcheck
defer rows.Close()
var (
_type string
_instance string
_name string
value string
)
if rows.Next() {
err = rows.Scan(&_type, &_instance, &_name, &value)
if err != nil {
return "", err
}
}
return value, rows.Err()
}
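
// For reference, on a multi-tenant cluster the query above yields a row shaped
// like the following (values illustrative); on a dedicated cluster the result
// set is empty and the function returns "":
//
//	Type | Instance       | Name          | Value
//	tidb | 127.0.0.1:4000 | keyspace-name | ks1
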
func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *options) (err error) {
build.LogInfo(build.Lightning)
o.logger.Info("cfg", zap.String("cfg", taskCfg.Redact()))
logutil.LogEnvVariables()
if split.WaitRegionOnlineAttemptTimes != taskCfg.TikvImporter.RegionCheckBackoffLimit {
// this causes a data race if lightning is used as a library, but the
// field is a hidden config, so we ignore that case
split.WaitRegionOnlineAttemptTimes = taskCfg.TikvImporter.RegionCheckBackoffLimit
}
metrics := metric.NewMetrics(o.promFactory)
metrics.RegisterTo(o.promRegistry)
defer func() {
metrics.UnregisterFrom(o.promRegistry)
}()
l.metrics = metrics
ctx := metric.WithMetric(taskCtx, metrics)
ctx = logutil.WithLogger(ctx, o.logger.Logger)
ctx, cancel := context.WithCancel(ctx)
l.cancelLock.Lock()
l.cancel = cancel
l.curTask = taskCfg
l.cancelLock.Unlock()
web.BroadcastStartTask()
defer func() {
cancel()
l.cancelLock.Lock()
l.cancel = nil
l.importer = nil
l.cancelLock.Unlock()
web.BroadcastEndTask(err)
}()
failpoint.Inject("SkipRunTask", func() {
if notifyCh, ok := l.ctx.Value(taskRunNotifyKey).(chan struct{}); ok {
select {
case notifyCh <- struct{}{}:
default:
}
}
if recorder, ok := l.ctx.Value(taskCfgRecorderKey).(chan *config.Config); ok {
select {
case recorder <- taskCfg:
case <-ctx.Done():
failpoint.Return(ctx.Err())
}
}
failpoint.Return(nil)
})
failpoint.Inject("SetCertExpiredSoon", func(val failpoint.Value) {
rootKeyPath := val.(string)
rootCaPath := taskCfg.Security.CAPath
keyPath := taskCfg.Security.KeyPath
certPath := taskCfg.Security.CertPath
if err := updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath, time.Second*10); err != nil {
panic(err)
}
})
failpoint.Inject("PrintStatus", func() {
defer func() {
finished, total := l.Status()
o.logger.Warn("PrintStatus Failpoint",
zap.Int64("finished", finished),
zap.Int64("total", total),
zap.Bool("equal", finished == total))
}()
})
if err := taskCfg.TiDB.Security.BuildTLSConfig(); err != nil {
return common.ErrInvalidTLSConfig.Wrap(err)
}
var mdl *mydump.MDLoader
var dbMetas []*mydump.MDDatabaseMeta
s := o.dumpFileStorage
if taskCfg.TikvImporter.Backend != config.BackendImportInto {
mdl, s, err = l.initDataSource(ctx, taskCfg, o)
if err != nil {
return err
}
dbMetas = mdl.GetDatabases()
web.BroadcastInitProgress(dbMetas)
}
db, keyspaceName, err := initDBAndKeyspace(ctx, taskCfg, o)
if err != nil {
return err
}
param := &importer.ControllerParam{
DBMetas: dbMetas,
Status: &l.status,
DumpFileStorage: s,
OwnExtStorage: o.dumpFileStorage == nil,
DB: db,
CheckpointStorage: o.checkpointStorage,
CheckpointName: o.checkpointName,
DupIndicator: o.dupIndicator,
KeyspaceName: keyspaceName,
}
var procedure LightningImporter
procedure, err = newImporter(ctx, taskCfg, param)
if err != nil {
o.logger.Error("restore failed", log.ShortError(err))
return errors.Trace(err)
}
l.cancelLock.Lock()
l.importer = procedure
l.cancelLock.Unlock()
failpoint.Inject("orphanWriterGoRoutine", func() {
// don't exit too quickly to expose panic
defer time.Sleep(time.Second * 10)
})
defer procedure.Close()
err = procedure.Run(ctx)
return errors.Trace(err)
}
func (l *Lightning) initDataSource(ctx context.Context, taskCfg *config.Config, o *options) (*mydump.MDLoader, storeapi.Storage, error) {
s := o.dumpFileStorage
if s == nil {
u, err := objstore.ParseBackend(taskCfg.Mydumper.SourceDir, nil)
if err != nil {
return nil, nil, common.NormalizeError(err)
}
s, err = objstore.New(ctx, u, &storeapi.Options{})
if err != nil {
return nil, nil, common.NormalizeError(err)
}
}
// a returned expectedErr means the walk saw at least one file
expectedErr := errors.New("Stop Iter")
walkErr := s.WalkDir(ctx, &storeapi.WalkOption{ListCount: 1}, func(string, int64) error {
// return an error when meet the first regular file to break the walk loop
return expectedErr
})
if !errors.ErrorEqual(walkErr, expectedErr) {
if walkErr == nil {
return nil, nil, common.ErrEmptySourceDir.GenWithStackByArgs(taskCfg.Mydumper.SourceDir)
}
return nil, nil, common.NormalizeOrWrapErr(common.ErrStorageUnknown, walkErr)
}
loadTask := o.logger.Begin(zap.InfoLevel, "load data source")
mdl, err := mydump.NewLoaderWithStore(
ctx, mydump.NewLoaderCfg(taskCfg), s,
mydump.WithScanFileConcurrency(l.curTask.App.RegionConcurrency*2),
)
loadTask.End(zap.ErrorLevel, err)
if err != nil {
return nil, nil, errors.Trace(err)
}
err = checkSystemRequirement(taskCfg, mdl.GetDatabases())
if err != nil {
o.logger.Error("check system requirements failed", zap.Error(err))
return nil, nil, common.ErrSystemRequirementNotMet.Wrap(err).GenWithStackByArgs()
}
// check table schema conflicts
err = checkSchemaConflict(taskCfg, mdl.GetDatabases())
if err != nil {
o.logger.Error("checkpoint schema conflicts with data files", zap.Error(err))
return nil, nil, errors.Trace(err)
}
return mdl, s, nil
}
func initDBAndKeyspace(ctx context.Context, taskCfg *config.Config, o *options) (*sql.DB, string, error) {
// db is non-nil only in unit tests
db := o.db
var err error
if db == nil {
// initialization of the default db must happen after BuildTLSConfig
db, err = importer.DBFromConfig(ctx, taskCfg.TiDB)
if err != nil {
return nil, "", common.ErrDBConnect.Wrap(err)
}
}
var keyspaceName string
if taskCfg.TikvImporter.Backend == config.BackendLocal {
keyspaceName = taskCfg.TikvImporter.KeyspaceName
if keyspaceName == "" {
keyspaceName, err = getKeyspaceName(db)
if err != nil && common.IsAccessDeniedNeedConfigPrivilegeError(err) {
// If the cluster is not multi-tenant, we don't really need to know the keyspace.
// Since the docs do not say the CONFIG privilege is required,
// spelling out the Access Denied error would only confuse users,
// so hide such benign errors unless the log level is DEBUG.
o.logger.Info("keyspace is unspecified and target user has no config privilege, assuming dedicated cluster")
if o.logger.Level() > zapcore.DebugLevel {
err = nil
}
}
if err != nil {
o.logger.Warn("unable to get keyspace name, lightning will use empty keyspace name", zap.Error(err))
}
}
o.logger.Info("acquired keyspace name", zap.String("keyspaceName", keyspaceName))
}
return db, keyspaceName, nil
}
// Stop stops the lightning server.
func (l *Lightning) Stop() {
l.cancelLock.Lock()
if l.cancel != nil {
l.taskCanceled = true
l.cancel()
}
l.cancelLock.Unlock()
if err := l.server.Shutdown(l.ctx); err != nil {
log.L().Warn("failed to shutdown HTTP server", log.ShortError(err))
}
l.shutdown()
}
// TaskCanceled returns whether the current task is canceled.
func (l *Lightning) TaskCanceled() bool {
l.cancelLock.Lock()
defer l.cancelLock.Unlock()
return l.taskCanceled
}
// Status returns the total size of source files already imported to TiKV and the total size of all source files.
func (l *Lightning) Status() (finished int64, total int64) {
finished = l.status.FinishedFileSize.Load()
total = l.status.TotalFileSize.Load()
return
}
// Metrics returns the metrics of lightning.
// It is initialized during `run`, so this might return nil.
func (l *Lightning) Metrics() *metric.Metrics {
return l.metrics
}
func writeJSONError(w http.ResponseWriter, code int, prefix string, err error) {
type errorResponse struct {
Error string `json:"error"`
}
w.WriteHeader(code)
if err != nil {
prefix += ": " + err.Error()
}
_ = json.NewEncoder(w).Encode(errorResponse{Error: prefix})
}
func parseTaskID(req *http.Request) (int64, string, error) {
path := strings.TrimPrefix(req.URL.Path, "/")
taskIDString := path
verb := ""
if i := strings.IndexByte(path, '/'); i >= 0 {
taskIDString = path[:i]
verb = path[i+1:]
}
taskID, err := strconv.ParseInt(taskIDString, 10, 64)
if err != nil {
return 0, "", err
}
return taskID, verb, nil
}
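
// parseTaskID input/output sketch (paths as received after the /tasks prefix
// has been stripped; values illustrative):
//
//	"/271828"       -> (271828, "", nil)
//	"/271828/front" -> (271828, "front", nil)
//	"/abc"          -> (0, "", *strconv.NumError)
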
func (l *Lightning) handleTask(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch req.Method {
case http.MethodGet:
taskID, _, err := parseTaskID(req)
// golint tells us to refactor this with a switch statement,
// but a switch with an init statement would not read any
// better here, so this chain of if-else stays.
//nolint:gocritic
if e, ok := err.(*strconv.NumError); ok && e.Num == "" {
l.handleGetTask(w)
} else if err == nil {
l.handleGetOneTask(w, req, taskID)
} else {
writeJSONError(w, http.StatusBadRequest, "invalid task ID", err)
}
case http.MethodPost:
l.handlePostTask(w, req)
case http.MethodDelete:
l.handleDeleteOneTask(w, req)
case http.MethodPatch:
l.handlePatchOneTask(w, req)
default:
w.Header().Set("Allow", http.MethodGet+", "+http.MethodPost+", "+http.MethodDelete+", "+http.MethodPatch)
writeJSONError(w, http.StatusMethodNotAllowed, "only GET, POST, DELETE and PATCH are allowed", nil)
}
}
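
// A hedged HTTP sketch of the /tasks endpoints handled above (assuming the
// status server listens on the default 127.0.0.1:8289; task.toml and the task
// ID are illustrative):
//
//	curl http://127.0.0.1:8289/tasks                        # current + queued IDs
//	curl -d @task.toml http://127.0.0.1:8289/tasks          # enqueue a TOML task
//	curl http://127.0.0.1:8289/tasks/271828                 # fetch one task config
//	curl -X DELETE http://127.0.0.1:8289/tasks/271828       # cancel or dequeue
//	curl -X PATCH http://127.0.0.1:8289/tasks/271828/front  # move to queue front
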
func (l *Lightning) handleGetTask(w http.ResponseWriter) {
var response struct {
Current *int64 `json:"current"`
QueuedIDs []int64 `json:"queue"`
}
l.serverLock.Lock()
if l.taskCfgs != nil {
response.QueuedIDs = l.taskCfgs.AllIDs()
} else {
response.QueuedIDs = []int64{}
}
l.serverLock.Unlock()
l.cancelLock.Lock()
if l.cancel != nil && l.curTask != nil {
response.Current = new(int64)
*response.Current = l.curTask.TaskID
}
l.cancelLock.Unlock()
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(response)
}
func (l *Lightning) handleGetOneTask(w http.ResponseWriter, req *http.Request, taskID int64) {
var task *config.Config
l.cancelLock.Lock()
if l.curTask != nil && l.curTask.TaskID == taskID {
task = l.curTask
}
l.cancelLock.Unlock()
if task == nil && l.taskCfgs != nil {
task, _ = l.taskCfgs.Get(taskID)
}
if task == nil {
writeJSONError(w, http.StatusNotFound, "task ID not found", nil)
return
}
json, err := json.Marshal(task)
if err != nil {
writeJSONError(w, http.StatusInternalServerError, "unable to serialize task", err)
return
}
writeBytesCompressed(w, req, json)
}
func (l *Lightning) handlePostTask(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Cache-Control", "no-store")
l.serverLock.Lock()
defer l.serverLock.Unlock()
if l.taskCfgs == nil {
// l.taskCfgs is non-nil only if Lightning is started with RunServer().
// Without server mode this pointer defaults to nil.
writeJSONError(w, http.StatusNotImplemented, "server-mode not enabled", nil)
return
}
type taskResponse struct {
ID int64 `json:"id"`
}
data, err := io.ReadAll(req.Body)
if err != nil {
writeJSONError(w, http.StatusBadRequest, "cannot read request", err)
return
}
log.L().Info("received task config")
cfg := config.NewConfig()
if err = cfg.LoadFromGlobal(l.globalCfg); err != nil {
writeJSONError(w, http.StatusInternalServerError, "cannot restore from global config", err)
return
}
if err = cfg.LoadFromTOML(data); err != nil {
writeJSONError(w, http.StatusBadRequest, "cannot parse task (must be TOML)", err)
return
}
if err = cfg.Adjust(l.ctx); err != nil {
writeJSONError(w, http.StatusBadRequest, "invalid task configuration", err)
return
}
l.taskCfgs.Push(cfg)
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(taskResponse{ID: cfg.TaskID})
}
func (l *Lightning) handleDeleteOneTask(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
taskID, _, err := parseTaskID(req)
if err != nil {
writeJSONError(w, http.StatusBadRequest, "invalid task ID", err)
return
}
var cancel context.CancelFunc
cancelSuccess := false
l.cancelLock.Lock()
if l.cancel != nil && l.curTask != nil && l.curTask.TaskID == taskID {
cancel = l.cancel
l.cancel = nil
}
l.cancelLock.Unlock()
if cancel != nil {
cancel()
cancelSuccess = true
} else if l.taskCfgs != nil {
cancelSuccess = l.taskCfgs.Remove(taskID)
}
log.L().Info("canceled task", zap.Int64("taskID", taskID), zap.Bool("success", cancelSuccess))
if cancelSuccess {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("{}"))
} else {
writeJSONError(w, http.StatusNotFound, "task ID not found", nil)
}
}
func (l *Lightning) handlePatchOneTask(w http.ResponseWriter, req *http.Request) {
if l.taskCfgs == nil {
writeJSONError(w, http.StatusNotImplemented, "server-mode not enabled", nil)
return
}
taskID, verb, err := parseTaskID(req)
if err != nil {
writeJSONError(w, http.StatusBadRequest, "invalid task ID", err)
return
}
moveSuccess := false
switch verb {
case "front":
moveSuccess = l.taskCfgs.MoveToFront(taskID)
case "back":
moveSuccess = l.taskCfgs.MoveToBack(taskID)
default:
writeJSONError(w, http.StatusBadRequest, "unknown patch action", nil)
return
}
if moveSuccess {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("{}"))
} else {
writeJSONError(w, http.StatusNotFound, "task ID not found", nil)
}
}
func writeBytesCompressed(w http.ResponseWriter, req *http.Request, b []byte) {
if !strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") {
_, _ = w.Write(b)
return
}
w.Header().Set("Content-Encoding", "gzip")
w.WriteHeader(http.StatusOK)
gw, _ := gzip.NewWriterLevel(w, gzip.BestSpeed)
_, _ = gw.Write(b)
_ = gw.Close()
}
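
// Client-side example (illustrative): curl's --compressed flag sends
// Accept-Encoding: gzip and transparently decompresses the response, matching
// the negotiation above.
//
//	curl --compressed http://127.0.0.1:8289/progress/task
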
func handleProgressTask(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
res, err := web.MarshalTaskProgress()
if err == nil {
writeBytesCompressed(w, req, res)
} else {
w.WriteHeader(http.StatusInternalServerError)
_ = json.NewEncoder(w).Encode(err.Error())
}
}
func handleProgressTable(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
tableName := req.URL.Query().Get("t")
res, err := web.MarshalTableCheckpoints(tableName)
if err == nil {
writeBytesCompressed(w, req, res)
} else {
if errors.IsNotFound(err) {
w.WriteHeader(http.StatusNotFound)
} else {
w.WriteHeader(http.StatusInternalServerError)
}
_ = json.NewEncoder(w).Encode(err.Error())
}
}
func (l *Lightning) handlePause(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch req.Method {
case http.MethodGet:
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"paused":%v}`, importer.DeliverPauser.IsPaused())
case http.MethodPut:
w.WriteHeader(http.StatusOK)
l.cancelLock.Lock()
imp := l.importer
l.cancelLock.Unlock()
if imp != nil {
if err := imp.Pause(req.Context()); err != nil {
log.L().Error("failed to pause", zap.Error(err))
writeJSONError(w, http.StatusInternalServerError, "failed to pause", err)
return
}
log.L().Info("progress paused")
} else {
importer.DeliverPauser.Pause()
log.L().Info("progress paused")
}
_, _ = w.Write([]byte("{}"))
default:
w.Header().Set("Allow", http.MethodGet+", "+http.MethodPut)
writeJSONError(w, http.StatusMethodNotAllowed, "only GET and PUT are allowed", nil)
}
}
func (l *Lightning) handleResume(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch req.Method {
case http.MethodPut:
w.WriteHeader(http.StatusOK)
l.cancelLock.Lock()
imp := l.importer
l.cancelLock.Unlock()
if imp != nil {
if err := imp.Resume(req.Context()); err != nil {
log.L().Error("failed to resume", zap.Error(err))
writeJSONError(w, http.StatusInternalServerError, "failed to resume", err)
return
}
log.L().Info("progress resumed")
} else {
importer.DeliverPauser.Resume()
log.L().Info("progress resumed")
}
_, _ = w.Write([]byte("{}"))
default:
w.Header().Set("Allow", http.MethodPut)
writeJSONError(w, http.StatusMethodNotAllowed, "only PUT is allowed", nil)
}
}
func handleLogLevel(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
var logLevel struct {
Level zapcore.Level `json:"level"`
}
switch req.Method {
case http.MethodGet:
logLevel.Level = log.Level()
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(logLevel)
case http.MethodPut, http.MethodPost:
if err := json.NewDecoder(req.Body).Decode(&logLevel); err != nil {
writeJSONError(w, http.StatusBadRequest, "invalid log level", err)
return
}
oldLevel := log.SetLevel(zapcore.InfoLevel)
log.L().Info("changed log level. No effects if task has specified its logger",
zap.Stringer("old", oldLevel),
zap.Stringer("new", logLevel.Level))
log.SetLevel(logLevel.Level)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("{}"))
default:
w.Header().Set("Allow", http.MethodGet+", "+http.MethodPut+", "+http.MethodPost)
writeJSONError(w, http.StatusMethodNotAllowed, "only GET, PUT and POST are allowed", nil)
}
}
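
// A hedged sketch for the /loglevel endpoint (level names follow zapcore.Level
// text forms such as "debug", "info", "warn", "error"):
//
//	curl http://127.0.0.1:8289/loglevel                               # {"level":"info"}
//	curl -X PUT -d '{"level":"debug"}' http://127.0.0.1:8289/loglevel
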
func checkSystemRequirement(cfg *config.Config, dbsMeta []*mydump.MDDatabaseMeta) error {
// in local mode we read and write a lot of L0 SST files, so we need to check the system's max-open-files limit
if cfg.TikvImporter.Backend == config.BackendLocal {
// estimate max open files = {top N(TableConcurrency) table sizes} / {MemoryTableSize}
tableTotalSizes := make([]int64, 0)
for _, dbs := range dbsMeta {
for _, tb := range dbs.Tables {
tableTotalSizes = append(tableTotalSizes, tb.TotalSize)
}
}
slices.SortFunc(tableTotalSizes, func(i, j int64) int {
return cmp.Compare(j, i)
})
topNTotalSize := int64(0)
for i := 0; i < len(tableTotalSizes) && i < cfg.App.TableConcurrency; i++ {
topNTotalSize += tableTotalSizes[i]
}
// region-concurrency: number of LocalWriters writing SST files.
// 2*totalSize/memCacheSize: number of Pebble MemCache files.
maxDBFiles := topNTotalSize / int64(cfg.TikvImporter.LocalWriterMemCacheSize) * 2
// the Pebble DB and all import routines need up to maxDBFiles fds for reading and writing.
maxOpenDBFiles := maxDBFiles * (1 + int64(cfg.TikvImporter.RangeConcurrency))
estimateMaxFiles := local.RlimT(cfg.App.RegionConcurrency) + local.RlimT(maxOpenDBFiles)
if err := local.VerifyRLimit(estimateMaxFiles); err != nil {
return err
}
}
return nil
}
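
// A worked example of the estimate above with illustrative numbers: if the top
// TableConcurrency tables total 100 GiB, LocalWriterMemCacheSize is 128 MiB,
// and RangeConcurrency is 16, then
//
//	maxDBFiles       = 100 GiB / 128 MiB * 2 = 1600
//	maxOpenDBFiles   = 1600 * (1 + 16)       = 27200
//	estimateMaxFiles = RegionConcurrency + 27200
//
// so the process needs an open-files rlimit comfortably above ~27k.
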
// checkSchemaConflict returns an error if the checkpoint table schema conflicts with the data files.
func checkSchemaConflict(cfg *config.Config, dbsMeta []*mydump.MDDatabaseMeta) error {
if cfg.Checkpoint.Enable && cfg.Checkpoint.Driver == config.CheckpointDriverMySQL {
for _, db := range dbsMeta {
if db.Name == cfg.Checkpoint.Schema {
for _, tb := range db.Tables {
if checkpoints.IsCheckpointTable(tb.Name) {
return common.ErrCheckpointSchemaConflict.GenWithStack("checkpoint table `%s`.`%s` conflict with data files. Please change the `checkpoint.schema` config or set `checkpoint.driver` to \"file\" instead", db.Name, tb.Name)
}
}
}
}
}
return nil
}
// CleanupMetas removes the table metas of the given table.
func CleanupMetas(ctx context.Context, cfg *config.Config, tableName string) error {
if tableName == common.AllTables {
tableName = ""
}
// try to clean up table metas if they exist
db, err := importer.DBFromConfig(ctx, cfg.TiDB)
if err != nil {
return errors.Trace(err)
}
tableMetaExist, err := common.TableExists(ctx, db, cfg.App.MetaSchemaName, importer.TableMetaTableName)
if err != nil {
return errors.Trace(err)
}
if tableMetaExist {
metaTableName := common.UniqueTable(cfg.App.MetaSchemaName, importer.TableMetaTableName)
if err = importer.RemoveTableMetaByTableName(ctx, db, metaTableName, tableName); err != nil {
return errors.Trace(err)
}
}
exist, err := common.TableExists(ctx, db, cfg.App.MetaSchemaName, importer.TaskMetaTableName)
if err != nil || !exist {
return errors.Trace(err)
}
return errors.Trace(importer.MaybeCleanupAllMetas(ctx, log.L(), db, cfg.App.MetaSchemaName, tableMetaExist))
}
// SwitchMode switches the mode of the TiKV cluster.
func SwitchMode(ctx context.Context, cli pdhttp.Client, tls *tls.Config, mode string, ranges ...*import_sstpb.Range) error {
var m import_sstpb.SwitchMode
switch mode {
case config.ImportMode:
m = import_sstpb.SwitchMode_Import
case config.NormalMode:
m = import_sstpb.SwitchMode_Normal
default:
return errors.Errorf("invalid mode %s, must use %s or %s", mode, config.ImportMode, config.NormalMode)
}
return tikv.ForAllStores(
ctx,
cli,
metapb.StoreState_Offline,
func(c context.Context, store *pdhttp.MetaStore) error {
return tikv.SwitchMode(c, tls, store.Address, m, ranges...)
},
)
}
func updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath string, expiry time.Duration) error {
rootKey, err := parsePrivateKey(rootKeyPath)
if err != nil {
return err
}
rootCaPem, err := os.ReadFile(rootCaPath)
if err != nil {
return err
}
rootCaDer, _ := pem.Decode(rootCaPem)
rootCa, err := x509.ParseCertificate(rootCaDer.Bytes)
if err != nil {
return err
}
key, err := parsePrivateKey(keyPath)
if err != nil {
return err
}
certPem, err := os.ReadFile(certPath)
if err != nil {
panic(err)
}
certDer, _ := pem.Decode(certPem)
cert, err := x509.ParseCertificate(certDer.Bytes)
if err != nil {
return err
}
cert.NotBefore = time.Now()
cert.NotAfter = time.Now().Add(expiry)
derBytes, err := x509.CreateCertificate(rand.Reader, cert, rootCa, &key.PublicKey, rootKey)
if err != nil {
return err
}
return os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), 0o600)
}
func parsePrivateKey(keyPath string) (*ecdsa.PrivateKey, error) {
keyPemBlock, err := os.ReadFile(keyPath)
if err != nil {
return nil, err
}
var keyDERBlock *pem.Block
for {
keyDERBlock, keyPemBlock = pem.Decode(keyPemBlock)
if keyDERBlock == nil {
return nil, errors.New("failed to find PEM block with type ending in \"PRIVATE KEY\"")
}
if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
break
}
}
return x509.ParseECPrivateKey(keyDERBlock.Bytes)
}
// LightningImporter is an interface that abstracts the execution lifecycle of different import backends.
// It allows TiDB Lightning to support multiple import engines (such as the standard Local/Importer
// backend and the specialized IMPORT INTO procedure) through a unified API.
//
// This interface should be implemented when adding a new import strategy or engine that requires
// integration with the Lightning task server. Unlike the legacy Controller pattern, which is often
// specifically tied to mydump-based data sourcing and global resource management, LightningImporter
// provides a polymorphic abstraction focused on the execution phase (Run, Pause, Resume, Close),
// enabling the higher-level server to manage various import procedures consistently.
type LightningImporter interface {
// Run starts the import process.
Run(ctx context.Context) error
// Pause pauses the import process.
Pause(ctx context.Context) error
// Resume resumes the import process.
Resume(ctx context.Context) error
// Close closes the importer and releases resources.
Close()
}
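
// A minimal conforming implementation sketch, useful as a template when wiring
// a new engine into newImporter below (noopImporter is an illustrative name,
// not an existing backend):
//
//	type noopImporter struct{}
//
//	func (noopImporter) Run(ctx context.Context) error    { return nil }
//	func (noopImporter) Pause(ctx context.Context) error  { return nil }
//	func (noopImporter) Resume(ctx context.Context) error { return nil }
//	func (noopImporter) Close()                           {}
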
// newImporter creates a new LightningImporter based on the configuration.
func newImporter(ctx context.Context, cfg *config.Config, param *importer.ControllerParam) (LightningImporter, error) {
switch cfg.TikvImporter.Backend {
case config.BackendImportInto:
return importinto.NewImporter(ctx, cfg, param.DB)
case config.BackendLocal, config.BackendTiDB:
return importer.NewImportController(ctx, cfg, param)
default:
return nil, errors.Errorf("unknown backend %s", cfg.TikvImporter.Backend)
}
}