diff --git a/br/pkg/restore/client.go b/br/pkg/restore/client.go index e46c55d7ae..7701248777 100644 --- a/br/pkg/restore/client.go +++ b/br/pkg/restore/client.go @@ -852,13 +852,38 @@ func (rc *Client) GetDBSchema(dom *domain.Domain, dbName model.CIStr) (*model.DB return info.SchemaByName(dbName) } -// CreateDatabase creates a database. -func (rc *Client) CreateDatabase(ctx context.Context, db *model.DBInfo) error { +// CreateDatabases creates databases. If the client has the db pool, it would create it. +func (rc *Client) CreateDatabases(ctx context.Context, dbs []*metautil.Database) error { if rc.IsSkipCreateSQL() { - log.Info("skip create database", zap.Stringer("name", db.Name)) + log.Info("skip create database") return nil } + if len(rc.dbPool) == 0 { + log.Info("create databases sequentially") + for _, db := range dbs { + err := rc.createDatabaseWithDBConn(ctx, db.Info, rc.db) + if err != nil { + return errors.Trace(err) + } + } + return nil + } + + log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool))) + eg, ectx := errgroup.WithContext(ctx) + workers := utils.NewWorkerPool(uint(len(rc.dbPool)), "DB DDL workers") + for _, db_ := range dbs { + db := db_ + workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { + conn := rc.dbPool[id%uint64(len(rc.dbPool))] + return rc.createDatabaseWithDBConn(ectx, db.Info, conn) + }) + } + return eg.Wait() +} + +func (rc *Client) createDatabaseWithDBConn(ctx context.Context, db *model.DBInfo, conn *DB) error { log.Info("create database", zap.Stringer("name", db.Name)) if !rc.supportPolicy { @@ -868,12 +893,12 @@ func (rc *Client) CreateDatabase(ctx context.Context, db *model.DBInfo) error { } if db.PlacementPolicyRef != nil { - if err := rc.db.ensurePlacementPolicy(ctx, db.PlacementPolicyRef.Name, rc.policyMap); err != nil { + if err := conn.ensurePlacementPolicy(ctx, db.PlacementPolicyRef.Name, rc.policyMap); err != nil { return errors.Trace(err) } } - return rc.db.CreateDatabase(ctx, db) + return conn.CreateDatabase(ctx, db) } // CreateTables creates multiple tables, and returns their rewrite rules. diff --git a/br/pkg/restore/client_test.go b/br/pkg/restore/client_test.go index 4407a90d07..c291440e9b 100644 --- a/br/pkg/restore/client_test.go +++ b/br/pkg/restore/client_test.go @@ -137,7 +137,7 @@ func TestCheckTargetClusterFresh(t *testing.T) { ctx := context.Background() require.NoError(t, client.CheckTargetClusterFresh(ctx)) - require.NoError(t, client.CreateDatabase(ctx, &model.DBInfo{Name: model.NewCIStr("user_db")})) + require.NoError(t, client.CreateDatabases(ctx, []*metautil.Database{{Info: &model.DBInfo{Name: model.NewCIStr("user_db")}}})) require.True(t, berrors.ErrRestoreNotFreshCluster.Equal(client.CheckTargetClusterFresh(ctx))) } diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 249bdda7ec..10b2d1a1b6 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -930,11 +930,8 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf return nil } - for _, db := range dbs { - err = client.CreateDatabase(ctx, db.Info) - if err != nil { - return errors.Trace(err) - } + if err = client.CreateDatabases(ctx, dbs); err != nil { + return errors.Trace(err) } // We make bigger errCh so we won't block on multi-part failed.