storage: support parallel write for gcs (#49545)
close pingcap/tidb#48443
@@ -33,6 +33,7 @@ go_library(
        "//br/pkg/storage",
        "//pkg/kv",
        "//pkg/metrics",
        "//pkg/util",
        "//pkg/util/hack",
        "//pkg/util/logutil",
        "//pkg/util/size",
@@ -209,6 +209,9 @@ func (r *byteReader) readNBytes(n int) ([]byte, error) {
		return bs[0], nil
	}
	// need to flatten bs
	if n <= 0 {
		return nil, errors.Errorf("illegal n (%d) when reading from external storage", n)
	}
	if n > int(size.GB) {
		return nil, errors.Errorf("read %d bytes from external storage, exceed max limit %d", n, size.GB)
	}
br/pkg/lightning/backend/external/reader.go (11 changes, vendored)
@@ -20,13 +20,14 @@ import (
	"io"
	"time"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/br/pkg/lightning/log"
	"github.com/pingcap/tidb/br/pkg/membuf"
	"github.com/pingcap/tidb/br/pkg/storage"
	"github.com/pingcap/tidb/pkg/metrics"
	"github.com/pingcap/tidb/pkg/util"
	"github.com/pingcap/tidb/pkg/util/logutil"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)

func readAllData(
@@ -65,12 +66,13 @@ func readAllData(
	if err != nil {
		return err
	}
	var eg errgroup.Group
	eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx)
	// TODO(lance6716): limit the concurrency of eg to 30 does not help
	for i := range dataFiles {
		i := i
		eg.Go(func() error {
			return readOneFile(
				ctx,
			err2 := readOneFile(
				egCtx,
				storage,
				dataFiles[i],
				startKey,
@@ -80,6 +82,7 @@ func readAllData(
				bufPool,
				output,
			)
			return errors.Annotatef(err2, "failed to read file %s", dataFiles[i])
		})
	}
	return eg.Wait()
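Note: the hunk above replaces a bare errgroup.Group with TiDB's util.NewErrorGroupWithRecoverWithCtx, so a panic in any file-reading goroutine is recovered into an error and the derived context (egCtx) cancels the sibling reads. A minimal sketch of the same fan-out shape using only the upstream errgroup.WithContext, which cancels on error but does not recover panics; readOne and files are illustrative stand-ins, not the PR's code:

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// readOne is a stand-in for readOneFile; it fails for demonstration.
func readOne(ctx context.Context, name string) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		return fmt.Errorf("failed to read file %s", name)
	}
}

func main() {
	eg, egCtx := errgroup.WithContext(context.Background())
	files := []string{"a", "b", "c"}
	for i := range files {
		i := i // per-iteration copy, as the diff does
		eg.Go(func() error {
			// the first error cancels egCtx, so other reads stop early
			return readOne(egCtx, files[i])
		})
	}
	fmt.Println(eg.Wait())
}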
br/pkg/lightning/backend/external/writer.go (163 changes, vendored)
@@ -197,7 +197,6 @@ func (b *WriterBuilder) Build(
		filenamePrefix: filenamePrefix,
		keyAdapter:     keyAdapter,
		writerID:       writerID,
		kvStore:        nil,
		onClose:        b.onClose,
		closed:         false,
		multiFileStats: make([]MultipleFilesStat, 1),
@@ -293,8 +292,7 @@ type Writer struct {
	filenamePrefix string
	keyAdapter     common.KeyAdapter

	kvStore *KeyValueStore
	rc      *rangePropertiesCollector
	rc *rangePropertiesCollector

	memSizeLimit uint64
@@ -400,88 +398,53 @@ func (w *Writer) recordMinMax(newMin, newMax tidbkv.Key, size uint64) {
	w.totalSize += size
}

const flushKVsRetryTimes = 3

func (w *Writer) flushKVs(ctx context.Context, fromClose bool) (err error) {
	if len(w.kvLocations) == 0 {
		return nil
	}

	logger := logutil.Logger(ctx)
	dataFile, statFile, dataWriter, statWriter, err := w.createStorageWriter(ctx)
	if err != nil {
		return err
	}

	var (
		savedBytes                  uint64
		statSize                    int
		sortDuration, writeDuration time.Duration
		writeStartTime              time.Time
	logger := logutil.Logger(ctx).With(
		zap.String("writer-id", w.writerID),
		zap.Int("sequence-number", w.currentSeq),
	)
	savedBytes = w.batchSize
	startTs := time.Now()

	kvCnt := len(w.kvLocations)
	defer func() {
		w.currentSeq++
		err1, err2 := dataWriter.Close(ctx), statWriter.Close(ctx)
		if err != nil {
			return
		}
		if err1 != nil {
			logger.Error("close data writer failed", zap.Error(err1))
			err = err1
			return
		}
		if err2 != nil {
			logger.Error("close stat writer failed", zap.Error(err2))
			err = err2
			return
		}
		writeDuration = time.Since(writeStartTime)
		logger.Info("flush kv",
			zap.Uint64("bytes", savedBytes),
			zap.Int("kv-cnt", kvCnt),
			zap.Int("stat-size", statSize),
			zap.Duration("sort-time", sortDuration),
			zap.Duration("write-time", writeDuration),
			zap.String("sort-speed(kv/s)", getSpeed(uint64(kvCnt), sortDuration.Seconds(), false)),
			zap.String("write-speed(bytes/s)", getSpeed(savedBytes, writeDuration.Seconds(), true)),
			zap.String("writer-id", w.writerID),
		)
		metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("write").Observe(writeDuration.Seconds())
		metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("write").Observe(float64(savedBytes) / 1024.0 / 1024.0 / writeDuration.Seconds())
		metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort_and_write").Observe(time.Since(startTs).Seconds())
		metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort_and_write").Observe(float64(savedBytes) / 1024.0 / 1024.0 / time.Since(startTs).Seconds())
	}()

	sortStart := time.Now()
	slices.SortFunc(w.kvLocations, func(i, j membuf.SliceLocation) int {
		return bytes.Compare(w.getKeyByLoc(i), w.getKeyByLoc(j))
	})
	sortDuration = time.Since(sortStart)

	writeStartTime = time.Now()
	sortDuration := time.Since(sortStart)
	metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort").Observe(sortDuration.Seconds())
	metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort").Observe(float64(savedBytes) / 1024.0 / 1024.0 / sortDuration.Seconds())
	w.kvStore, err = NewKeyValueStore(ctx, dataWriter, w.rc)
	if err != nil {
		return err
	}
	metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / sortDuration.Seconds())

	for _, pair := range w.kvLocations {
		err = w.kvStore.addEncodedData(w.kvBuffer.GetSlice(pair))
		if err != nil {
			return err
	writeStartTime := time.Now()
	var dataFile, statFile string
	for i := 0; i < flushKVsRetryTimes; i++ {
		dataFile, statFile, err = w.flushSortedKVs(ctx)
		if err == nil {
			break
		}
		logger.Warn("flush sorted kv failed",
			zap.Error(err),
			zap.Int("retry-count", i),
		)
	}

	w.kvStore.Close()
	encodedStat := w.rc.encode()
	statSize = len(encodedStat)
	_, err = statWriter.Write(ctx, encodedStat)
	if err != nil {
		return err
	}
	writeDuration := time.Since(writeStartTime)
	kvCnt := len(w.kvLocations)
	logger.Info("flush kv",
		zap.Uint64("bytes", w.batchSize),
		zap.Int("kv-cnt", kvCnt),
		zap.Duration("sort-time", sortDuration),
		zap.Duration("write-time", writeDuration),
		zap.String("sort-speed(kv/s)", getSpeed(uint64(kvCnt), sortDuration.Seconds(), false)),
		zap.String("writer-id", w.writerID),
	)
	totalDuration := time.Since(sortStart)
	metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort_and_write").Observe(totalDuration.Seconds())
	metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort_and_write").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / totalDuration.Seconds())

	minKey, maxKey := w.getKeyByLoc(w.kvLocations[0]), w.getKeyByLoc(w.kvLocations[len(w.kvLocations)-1])
	w.recordMinMax(minKey, maxKey, uint64(w.kvSize))
@@ -507,9 +470,73 @@ func (w *Writer) flushKVs(ctx context.Context, fromClose bool) (err error) {
	w.kvBuffer.Reset()
	w.rc.reset()
	w.batchSize = 0
	w.currentSeq++
	return nil
}

func (w *Writer) flushSortedKVs(ctx context.Context) (string, string, error) {
	logger := logutil.Logger(ctx).With(
		zap.String("writer-id", w.writerID),
		zap.Int("sequence-number", w.currentSeq),
	)
	writeStartTime := time.Now()
	dataFile, statFile, dataWriter, statWriter, err := w.createStorageWriter(ctx)
	if err != nil {
		return "", "", err
	}
	defer func() {
		// close the writers when meet error. If no error happens, writers will
		// be closed outside and assigned to nil.
		if dataWriter != nil {
			_ = dataWriter.Close(ctx)
		}
		if statWriter != nil {
			_ = statWriter.Close(ctx)
		}
	}()
	kvStore, err := NewKeyValueStore(ctx, dataWriter, w.rc)
	if err != nil {
		return "", "", err
	}

	for _, pair := range w.kvLocations {
		err = kvStore.addEncodedData(w.kvBuffer.GetSlice(pair))
		if err != nil {
			return "", "", err
		}
	}

	kvStore.Close()
	encodedStat := w.rc.encode()
	statSize := len(encodedStat)
	_, err = statWriter.Write(ctx, encodedStat)
	if err != nil {
		return "", "", err
	}
	err = dataWriter.Close(ctx)
	dataWriter = nil
	if err != nil {
		return "", "", err
	}
	err = statWriter.Close(ctx)
	statWriter = nil
	if err != nil {
		return "", "", err
	}

	writeDuration := time.Since(writeStartTime)
	logger.Info("flush sorted kv",
		zap.Uint64("bytes", w.batchSize),
		zap.Int("stat-size", statSize),
		zap.Duration("write-time", writeDuration),
		zap.String("write-speed(bytes/s)", getSpeed(w.batchSize, writeDuration.Seconds(), true)),
	)
	metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("write").Observe(writeDuration.Seconds())
	metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("write").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / writeDuration.Seconds())

	return dataFile, statFile, nil
}

func (w *Writer) getKeyByLoc(loc membuf.SliceLocation) []byte {
	block := w.kvBuffer.GetSlice(loc)
	keyLen := binary.BigEndian.Uint64(block[:lengthBytes])
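Note: after this refactor, flushKVs only sorts and then retries the whole flushSortedKVs call up to flushKVsRetryTimes (3) times; each attempt recreates the data and stat writers, so a failed upload is restarted from scratch rather than resumed mid-stream. A minimal sketch of that retry shape, with flushAttempt as a hypothetical stand-in for flushSortedKVs:

package main

import (
	"errors"
	"fmt"
)

const flushKVsRetryTimes = 3

// flushAttempt stands in for flushSortedKVs; it fails twice, then succeeds.
func flushAttempt(i int) (string, string, error) {
	if i < 2 {
		return "", "", errors.New("transient upload error")
	}
	return "data-file", "stat-file", nil
}

func main() {
	var (
		dataFile, statFile string
		err                error
	)
	for i := 0; i < flushKVsRetryTimes; i++ {
		dataFile, statFile, err = flushAttempt(i)
		if err == nil {
			break
		}
		fmt.Printf("flush sorted kv failed, retry-count=%d err=%v\n", i, err)
	}
	if err != nil {
		return // retries exhausted; flushKVs returns the last error
	}
	fmt.Println("flushed to", dataFile, statFile)
}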
@@ -7,6 +7,7 @@ go_library(
        "compress.go",
        "flags.go",
        "gcs.go",
        "gcs_extra.go",
        "hdfs.go",
        "helper.go",
        "ks3.go",
@@ -49,6 +50,7 @@ go_library(
        "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//bloberror",
        "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//blockblob",
        "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//container",
        "@com_github_go_resty_resty_v2//:resty",
        "@com_github_google_uuid//:uuid",
        "@com_github_klauspost_compress//gzip",
        "@com_github_klauspost_compress//snappy",
@@ -99,6 +99,7 @@ func (options *GCSBackendOptions) parseFromFlags(flags *pflag.FlagSet) error {
type GCSStorage struct {
	gcs    *backuppb.GCS
	bucket *storage.BucketHandle
	cli    *storage.Client
}

// GetBucketHandle gets the handle to the GCS API on the bucket.
@@ -272,12 +273,29 @@ func (s *GCSStorage) URI() string {
}

// Create implements ExternalStorage interface.
func (s *GCSStorage) Create(ctx context.Context, name string, _ *WriterOption) (ExternalFileWriter, error) {
	object := s.objectName(name)
	wc := s.bucket.Object(object).NewWriter(ctx)
	wc.StorageClass = s.gcs.StorageClass
	wc.PredefinedACL = s.gcs.PredefinedAcl
	return newFlushStorageWriter(wc, &emptyFlusher{}, wc), nil
func (s *GCSStorage) Create(ctx context.Context, name string, wo *WriterOption) (ExternalFileWriter, error) {
	// NewGCSWriter requires real testing environment on Google Cloud.
	mockGCS := intest.InTest && strings.Contains(s.gcs.GetEndpoint(), "127.0.0.1")
	if wo == nil || wo.Concurrency <= 1 || mockGCS {
		object := s.objectName(name)
		wc := s.bucket.Object(object).NewWriter(ctx)
		wc.StorageClass = s.gcs.StorageClass
		wc.PredefinedACL = s.gcs.PredefinedAcl
		return newFlushStorageWriter(wc, &emptyFlusher{}, wc), nil
	}
	uri := s.objectName(name)
	// 5MB is the minimum part size for GCS.
	partSize := int64(gcsMinimumChunkSize)
	if wo.PartSize > partSize {
		partSize = wo.PartSize
	}
	w, err := NewGCSWriter(ctx, s.cli, uri, partSize, wo.Concurrency, s.gcs.Bucket)
	if err != nil {
		return nil, errors.Trace(err)
	}
	fw := newFlushStorageWriter(w, &emptyFlusher{}, w)
	bw := newBufferedWriter(fw, int(partSize), NoCompression)
	return bw, nil
}

// Rename file name from oldFileName to newFileName.
@@ -371,7 +389,7 @@ skipHandleCred:
		// so we need find sst in slash directory
		gcs.Prefix += "//"
	}
	return &GCSStorage{gcs: gcs, bucket: bucket}, nil
	return &GCSStorage{gcs: gcs, bucket: bucket, cli: client}, nil
}

func hasSSTFiles(ctx context.Context, bucket *storage.BucketHandle, prefix string) bool {
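Note: in the parallel path, Create floors the part size at gcsMinimumChunkSize (5 * 1024 * 1024 bytes, which the diff's comment calls 5MB) and the XML API caps one upload at gcsMaximumParts (10000) parts, so a single object written through GCSWriter holds at most partSize * 10000 bytes; callers expecting larger files should raise WriterOption.PartSize. A quick check of that bound under the default floor (constants copied from gcs_extra.go below):

package main

import "fmt"

func main() {
	const (
		gcsMinimumChunkSize = 5 * 1024 * 1024 // from the diff
		gcsMaximumParts     = 10000           // from the diff
	)
	// With the default floor, one writer tops out at about 48.8 GiB.
	maxBytes := int64(gcsMinimumChunkSize) * gcsMaximumParts
	fmt.Printf("max object size at 5 MiB parts: %.1f GiB\n",
		float64(maxBytes)/(1<<30))
}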
br/pkg/storage/gcs_extra.go (419 lines, new file)
@@ -0,0 +1,419 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Learned from https://github.com/liqiuqing/gcsmpu

package storage

import (
	"bytes"
	"context"
	"encoding/xml"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"runtime"
	"slices"
	"strconv"
	"sync"
	"time"

	"cloud.google.com/go/storage"
	"github.com/go-resty/resty/v2"
	"go.uber.org/atomic"
)
// GCSWriter uses XML multipart upload API to upload a single file.
// https://cloud.google.com/storage/docs/multipart-uploads.
// GCSWriter will attempt to cancel uploads that fail due to an exception.
// If the upload fails in a way that precludes cancellation, such as a
// hardware failure, process termination, or power outage, then the incomplete
// upload may persist indefinitely. To mitigate this, set the
// `AbortIncompleteMultipartUpload` with a nonzero `Age` in bucket lifecycle
// rules, or refer to the XML API documentation linked above to learn more
// about how to list and delete individual uploads.
type GCSWriter struct {
	uploadBase
	mutex       sync.Mutex
	xmlMPUParts []*xmlMPUPart
	wg          sync.WaitGroup
	err         atomic.Error
	chunkSize   int64
	workers     int
	totalSize   int64
	uploadID    string
	chunkCh     chan chunk
	curPart     int
}

// NewGCSWriter returns a GCSWriter which uses GCS multipart upload API behind the scenes.
func NewGCSWriter(
	ctx context.Context,
	cli *storage.Client,
	uri string,
	partSize int64,
	parallelCnt int,
	bucketName string,
) (*GCSWriter, error) {
	if partSize < gcsMinimumChunkSize || partSize > gcsMaximumChunkSize {
		return nil, fmt.Errorf(
			"invalid chunk size: %d. Chunk size must be between %d and %d",
			partSize, gcsMinimumChunkSize, gcsMaximumChunkSize,
		)
	}

	w := &GCSWriter{
		uploadBase: uploadBase{
			ctx:             ctx,
			cli:             cli,
			bucket:          bucketName,
			blob:            uri,
			retry:           defaultRetry,
			signedURLExpiry: defaultSignedURLExpiry,
		},
		chunkSize: partSize,
		workers:   parallelCnt,
	}
	if err := w.init(); err != nil {
		return nil, fmt.Errorf("failed to initiate GCSWriter: %w", err)
	}

	return w, nil
}

func (w *GCSWriter) init() error {
	opts := &storage.SignedURLOptions{
		Scheme:          storage.SigningSchemeV4,
		Method:          "POST",
		Expires:         time.Now().Add(w.signedURLExpiry),
		QueryParameters: url.Values{mpuInitiateQuery: []string{""}},
	}
	u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
	if err != nil {
		return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
	}

	client := resty.New()
	resp, err := client.R().Post(u)
	if err != nil {
		return fmt.Errorf("POST request failed: %s", err)
	}

	if resp.StatusCode() != http.StatusOK {
		return fmt.Errorf("POST request returned non-OK status: %d", resp.StatusCode())
	}
	body := resp.Body()

	result := InitiateMultipartUploadResult{}
	err = xml.Unmarshal(body, &result)
	if err != nil {
		return fmt.Errorf("failed to unmarshal response body: %s", err)
	}

	uploadID := result.UploadId
	w.uploadID = uploadID
	w.chunkCh = make(chan chunk)
	for i := 0; i < w.workers; i++ {
		w.wg.Add(1)
		go w.readChunk(w.chunkCh)
	}
	w.curPart = 1
	return nil
}
func (w *GCSWriter) readChunk(ch chan chunk) {
	defer w.wg.Done()
	for {
		data, ok := <-ch
		if !ok {
			break
		}

		select {
		case <-w.ctx.Done():
			data.cleanup()
			return
		default:
			part := &xmlMPUPart{
				uploadBase: w.uploadBase,
				uploadID:   w.uploadID,
				buf:        data.buf,
				partNumber: data.num,
			}
			if w.err.Load() == nil {
				if err := part.Upload(); err != nil {
					w.err.Store(err)
				}
			}
			part.buf = nil
			w.appendMPUPart(part)
			data.cleanup()
		}
	}
}

// Write uploads given bytes as a part to Google Cloud Storage. Write is not
// concurrent safe.
func (w *GCSWriter) Write(p []byte) (n int, err error) {
	if w.curPart > gcsMaximumParts {
		err = fmt.Errorf("exceed maximum parts %d", gcsMaximumParts)
		if w.err.Load() == nil {
			w.err.Store(err)
		}
		return 0, err
	}
	buf := make([]byte, len(p))
	copy(buf, p)
	w.chunkCh <- chunk{
		buf:     buf,
		num:     w.curPart,
		cleanup: func() {},
	}
	w.curPart++
	return len(p), nil
}

// Close finishes the upload.
func (w *GCSWriter) Close() error {
	close(w.chunkCh)
	w.wg.Wait()

	if err := w.err.Load(); err != nil {
		return err
	}

	err := w.finalizeXMLMPU()
	if err == nil {
		return nil
	}
	errC := w.cancel()
	if errC != nil {
		return fmt.Errorf("failed to finalize multipart upload: %s, Failed to cancel multipart upload: %s", err, errC)
	}
	return fmt.Errorf("failed to finalize multipart upload: %s", err)
}
const (
	mpuInitiateQuery   = "uploads"
	mpuPartNumberQuery = "partNumber"
	mpuUploadIDQuery   = "uploadId"
)

type uploadBase struct {
	cli             *storage.Client
	ctx             context.Context
	bucket          string
	blob            string
	retry           int
	signedURLExpiry time.Duration
}

const (
	defaultRetry           = 3
	defaultSignedURLExpiry = 6 * time.Hour

	gcsMinimumChunkSize = 5 * 1024 * 1024        // 5 MB
	gcsMaximumChunkSize = 5 * 1024 * 1024 * 1024 // 5 GB
	gcsMaximumParts     = 10000
)

type InitiateMultipartUploadResult struct {
	XMLName  xml.Name `xml:"InitiateMultipartUploadResult"`
	Text     string   `xml:",chardata"`
	Xmlns    string   `xml:"xmlns,attr"`
	Bucket   string   `xml:"Bucket"`
	Key      string   `xml:"Key"`
	UploadId string   `xml:"UploadId"`
}

type Part struct {
	Text       string `xml:",chardata"`
	PartNumber int    `xml:"PartNumber"`
	ETag       string `xml:"ETag"`
}

type CompleteMultipartUpload struct {
	XMLName xml.Name `xml:"CompleteMultipartUpload"`
	Text    string   `xml:",chardata"`
	Parts   []Part   `xml:"Part"`
}
func (w *GCSWriter) finalizeXMLMPU() error {
	finalXMLRoot := CompleteMultipartUpload{
		Parts: make([]Part, 0, len(w.xmlMPUParts)),
	}
	slices.SortFunc(w.xmlMPUParts, func(a, b *xmlMPUPart) int {
		return a.partNumber - b.partNumber
	})
	for _, part := range w.xmlMPUParts {
		part := Part{
			PartNumber: part.partNumber,
			ETag:       part.etag,
		}
		finalXMLRoot.Parts = append(finalXMLRoot.Parts, part)
	}

	xmlBytes, err := xml.Marshal(finalXMLRoot)
	if err != nil {
		return fmt.Errorf("failed to encode XML: %v", err)
	}

	opts := &storage.SignedURLOptions{
		Scheme:          storage.SigningSchemeV4,
		Method:          "POST",
		Expires:         time.Now().Add(w.signedURLExpiry),
		QueryParameters: url.Values{mpuUploadIDQuery: []string{w.uploadID}},
	}
	u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
	if err != nil {
		return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
	}

	client := resty.New()
	resp, err := client.R().SetBody(xmlBytes).Post(u)
	if err != nil {
		return fmt.Errorf("POST request failed: %s", err)
	}

	if resp.StatusCode() != http.StatusOK {
		return fmt.Errorf("POST request returned non-OK status: %d, body: %s", resp.StatusCode(), resp.String())
	}
	return nil
}

type chunk struct {
	buf     []byte
	num     int
	cleanup func()
}

func (w *GCSWriter) appendMPUPart(part *xmlMPUPart) {
	w.mutex.Lock()
	defer w.mutex.Unlock()

	w.xmlMPUParts = append(w.xmlMPUParts, part)
}

func (w *GCSWriter) cancel() error {
	opts := &storage.SignedURLOptions{
		Scheme:          storage.SigningSchemeV4,
		Method:          "DELETE",
		Expires:         time.Now().Add(w.signedURLExpiry),
		QueryParameters: url.Values{mpuUploadIDQuery: []string{w.uploadID}},
	}
	u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
	if err != nil {
		return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
	}

	client := resty.New()
	resp, err := client.R().Delete(u)
	if err != nil {
		return fmt.Errorf("DELETE request failed: %s", err)
	}

	if resp.StatusCode() != http.StatusNoContent {
		return fmt.Errorf("DELETE request returned non-204 status: %d", resp.StatusCode())
	}

	return nil
}

type xmlMPUPart struct {
	uploadBase
	buf        []byte
	uploadID   string
	partNumber int
	etag       string
}

func (p *xmlMPUPart) Clone() *xmlMPUPart {
	return &xmlMPUPart{
		uploadBase: p.uploadBase,
		uploadID:   p.uploadID,
		buf:        p.buf,
		partNumber: p.partNumber,
	}
}

func (p *xmlMPUPart) Upload() error {
	var err error
	for i := 0; i < p.retry; i++ {
		err = p.upload()
		if err == nil {
			return nil
		}
	}

	return fmt.Errorf("failed to upload part %d: %w", p.partNumber, err)
}

func (p *xmlMPUPart) upload() error {
	opts := &storage.SignedURLOptions{
		Scheme:  storage.SigningSchemeV4,
		Method:  "PUT",
		Expires: time.Now().Add(p.signedURLExpiry),
		QueryParameters: url.Values{
			mpuUploadIDQuery:   []string{p.uploadID},
			mpuPartNumberQuery: []string{strconv.Itoa(p.partNumber)},
		},
	}

	u, err := p.cli.Bucket(p.bucket).SignedURL(p.blob, opts)
	if err != nil {
		return fmt.Errorf("Bucket(%q).SignedURL: %s", p.bucket, err)
	}

	req, err := http.NewRequest("PUT", u, bytes.NewReader(p.buf))
	if err != nil {
		return fmt.Errorf("PUT request failed: %s", err)
	}
	req = req.WithContext(p.ctx)

	client := &http.Client{
		Transport: createTransport(nil),
	}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("PUT request failed: %s", err)
	}
	defer resp.Body.Close()

	p.etag = resp.Header.Get("ETag")

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("PUT request returned non-OK status: %d", resp.StatusCode)
	}
	return nil
}

func createTransport(localAddr net.Addr) *http.Transport {
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	if localAddr != nil {
		dialer.LocalAddr = localAddr
	}
	return &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		MaxIdleConnsPerHost:   runtime.GOMAXPROCS(0) + 1,
	}
}
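Note: the GCSWriter doc comment in gcs_extra.go recommends an AbortIncompleteMultipartUpload lifecycle rule so parts orphaned by a crash are eventually deleted. A hedged sketch of configuring such a rule with the same cloud.google.com/go/storage client; the bucket name and the 7-day age are illustrative, and the AbortIncompleteMPUAction constant assumes a reasonably recent client version:

package main

import (
	"context"
	"log"

	"cloud.google.com/go/storage"
)

func main() {
	ctx := context.Background()
	cli, err := storage.NewClient(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()

	// Abort multipart uploads that have been incomplete for 7 days.
	rule := storage.LifecycleRule{
		Action:    storage.LifecycleAction{Type: storage.AbortIncompleteMPUAction},
		Condition: storage.LifecycleCondition{AgeInDays: 7},
	}
	_, err = cli.Bucket("my-bucket").Update(ctx, storage.BucketAttrsToUpdate{
		Lifecycle: &storage.Lifecycle{Rules: []storage.LifecycleRule{rule}},
	})
	if err != nil {
		log.Fatal(err)
	}
}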
@@ -3,7 +3,10 @@
package storage

import (
	"bytes"
	"context"
	"crypto/rand"
	"flag"
	"fmt"
	"io"
	"os"
@@ -460,3 +463,39 @@ func TestReadRange(t *testing.T) {
	require.NoError(t, err)
	require.Equal(t, []byte("234"), content[:n])
}

var testingStorageURI = flag.String("testing-storage-uri", "", "the URI of the storage used for testing")

func openTestingStorage(t *testing.T) ExternalStorage {
	if *testingStorageURI == "" {
		t.Skip("testingStorageURI is not set")
	}
	s, err := NewFromURL(context.Background(), *testingStorageURI)
	require.NoError(t, err)
	return s
}

func TestMultiPartUpload(t *testing.T) {
	ctx := context.Background()

	s := openTestingStorage(t)
	if _, ok := s.(*GCSStorage); !ok {
		t.Skipf("only test GCSStorage, got %T", s)
	}

	filename := "TestMultiPartUpload"
	// just get some random content, use any seed is enough
	data := make([]byte, 100*1024*1024)
	rand.Read(data)
	w, err := s.Create(ctx, filename, &WriterOption{Concurrency: 10})
	require.NoError(t, err)
	_, err = w.Write(ctx, data)
	require.NoError(t, err)
	err = w.Close(ctx)
	require.NoError(t, err)

	got, err := s.ReadFile(ctx, filename)
	require.NoError(t, err)
	cmp := bytes.Compare(data, got)
	require.Zero(t, cmp)
}
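Note: TestMultiPartUpload only runs against a real bucket; it is skipped unless the testing-storage-uri flag added above is set. It could be invoked along the lines of `go test ./br/pkg/storage/ -run TestMultiPartUpload -args -testing-storage-uri='gcs://bucket/prefix?credentials-file=...'` (the URI shape is illustrative; `credentials-file` is the query parameter ParseBackend already accepts, as the hunk below shows).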
@@ -147,6 +147,12 @@ func TestCreateStorage(t *testing.T) {
	require.Equal(t, "https://gcs.example.com/", gcs.Endpoint)
	require.Equal(t, "fakeCredentials", gcs.CredentialsBlob)

	s, err = ParseBackend("gcs://bucket?endpoint=http://127.0.0.1/", gcsOpt)
	require.NoError(t, err)
	gcs = s.GetGcs()
	require.NotNil(t, gcs)
	require.Equal(t, "http://127.0.0.1/", gcs.Endpoint)

	err = os.WriteFile(fakeCredentialsFile, []byte("fakeCreds2"), credFilePerm)
	require.NoError(t, err)
	s, err = ParseBackend("gs://bucket4/backup/?credentials-file="+url.QueryEscape(fakeCredentialsFile), nil)
go.mod (1 change)
@@ -41,6 +41,7 @@ require (
	github.com/fatih/color v1.15.0
	github.com/fsouza/fake-gcs-server v1.44.0
	github.com/go-ldap/ldap/v3 v3.4.4
	github.com/go-resty/resty/v2 v2.7.0
	github.com/go-sql-driver/mysql v1.7.1
	github.com/gogo/protobuf v1.3.2
	github.com/golang/protobuf v1.5.3
go.sum (3 changes)
@@ -295,6 +295,8 @@ github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AE
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY=
github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
@@ -1095,6 +1097,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210427231257-85d9c07bbe3a/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220517181318-183a9ca12b87/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=