// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Learned from https://github.com/liqiuqing/gcsmpu
|
|
|
|
package objstore
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/xml"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"runtime"
|
|
"slices"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"cloud.google.com/go/storage"
|
|
"github.com/go-resty/resty/v2"
|
|
"go.uber.org/atomic"
|
|
)
|
|
|
|
// GCSWriter uses XML multipart upload API to upload a single file.
|
|
// https://cloud.google.com/storage/docs/multipart-uploads.
|
|
// GCSWriter will attempt to cancel uploads that fail due to an exception.
|
|
// If the upload fails in a way that precludes cancellation, such as a
|
|
// hardware failure, process termination, or power outage, then the incomplete
|
|
// upload may persist indefinitely. To mitigate this, set the
|
|
// `AbortIncompleteMultipartUpload` with a nonzero `Age` in bucket lifecycle
|
|
// rules, or refer to the XML API documentation linked above to learn more
|
|
// about how to list and delete individual downloads.
|
|
type GCSWriter struct {
|
|
uploadBase
|
|
mutex sync.Mutex
|
|
xmlMPUParts []*xmlMPUPart
|
|
wg sync.WaitGroup
|
|
err atomic.Error
|
|
chunkSize int64
|
|
workers int
|
|
totalSize int64
|
|
uploadID string
|
|
chunkCh chan chunk
|
|
curPart int
|
|
}
|
|
|
|
// NewGCSWriter returns a GCSWriter which uses GCS multipart upload API behind the scene.
|
|
func NewGCSWriter(
|
|
ctx context.Context,
|
|
cli *storage.Client,
|
|
uri string,
|
|
partSize int64,
|
|
parallelCnt int,
|
|
bucketName string,
|
|
) (*GCSWriter, error) {
|
|
if partSize < gcsMinimumChunkSize || partSize > gcsMaximumChunkSize {
|
|
return nil, fmt.Errorf(
|
|
"invalid chunk size: %d. Chunk size must be between %d and %d",
|
|
partSize, gcsMinimumChunkSize, gcsMaximumChunkSize,
|
|
)
|
|
}
|
|
|
|
w := &GCSWriter{
|
|
uploadBase: uploadBase{
|
|
ctx: ctx,
|
|
cli: cli,
|
|
bucket: bucketName,
|
|
blob: uri,
|
|
retry: defaultRetry,
|
|
signedURLExpiry: defaultSignedURLExpiry,
|
|
},
|
|
chunkSize: partSize,
|
|
workers: parallelCnt,
|
|
}
|
|
if err := w.init(); err != nil {
|
|
return nil, fmt.Errorf("failed to initiate GCSWriter: %w", err)
|
|
}
|
|
|
|
return w, nil
|
|
}
|
|
|
|
func (w *GCSWriter) init() error {
|
|
opts := &storage.SignedURLOptions{
|
|
Scheme: storage.SigningSchemeV4,
|
|
Method: "POST",
|
|
Expires: time.Now().Add(w.signedURLExpiry),
|
|
QueryParameters: url.Values{mpuInitiateQuery: []string{""}},
|
|
}
|
|
u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
|
|
if err != nil {
|
|
return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
|
|
}
|
|
|
|
client := resty.New()
|
|
resp, err := client.R().Post(u)
|
|
if err != nil {
|
|
return fmt.Errorf("POST request failed: %s", err)
|
|
}
|
|
|
|
if resp.StatusCode() != http.StatusOK {
|
|
return fmt.Errorf("POST request returned non-OK status: %d", resp.StatusCode())
|
|
}
|
|
body := resp.Body()
|
|
|
|
result := InitiateMultipartUploadResult{}
|
|
err = xml.Unmarshal(body, &result)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to unmarshal response body: %s", err)
|
|
}
|
|
|
|
uploadID := result.UploadID
|
|
w.uploadID = uploadID
|
|
w.chunkCh = make(chan chunk)
|
|
for range w.workers {
|
|
w.wg.Add(1)
|
|
go w.readChunk(w.chunkCh)
|
|
}
|
|
w.curPart = 1
|
|
return nil
|
|
}
|
|
|
|
func (w *GCSWriter) readChunk(ch chan chunk) {
|
|
defer w.wg.Done()
|
|
for {
|
|
data, ok := <-ch
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
func() {
|
|
activeUploadWorkerCnt.Add(1)
|
|
defer activeUploadWorkerCnt.Add(-1)
|
|
|
|
select {
|
|
case <-w.ctx.Done():
|
|
data.cleanup()
|
|
w.err.CompareAndSwap(nil, w.ctx.Err())
|
|
default:
|
|
part := &xmlMPUPart{
|
|
uploadBase: w.uploadBase,
|
|
uploadID: w.uploadID,
|
|
buf: data.buf,
|
|
partNumber: data.num,
|
|
}
|
|
if w.err.Load() == nil {
|
|
if err := part.Upload(); err != nil {
|
|
w.err.Store(err)
|
|
}
|
|
}
|
|
part.buf = nil
|
|
w.appendMPUPart(part)
|
|
data.cleanup()
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
// Write uploads given bytes as a part to Google Cloud Storage. Write is not
|
|
// concurrent safe.
|
|
func (w *GCSWriter) Write(p []byte) (n int, err error) {
|
|
if w.curPart > gcsMaximumParts {
|
|
err = fmt.Errorf("exceed maximum parts %d", gcsMaximumParts)
|
|
if w.err.Load() == nil {
|
|
w.err.Store(err)
|
|
}
|
|
return 0, err
|
|
}
|
|
buf := make([]byte, len(p))
|
|
copy(buf, p)
|
|
w.chunkCh <- chunk{
|
|
buf: buf,
|
|
num: w.curPart,
|
|
cleanup: func() {},
|
|
}
|
|
w.curPart++
|
|
return len(p), nil
|
|
}
|
|
|
|
// Close finishes the upload.
|
|
func (w *GCSWriter) Close() error {
|
|
close(w.chunkCh)
|
|
w.wg.Wait()
|
|
|
|
if err := w.err.Load(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(w.xmlMPUParts) == 0 {
|
|
return nil
|
|
}
|
|
err := w.finalizeXMLMPU()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
errC := w.cancel()
|
|
if errC != nil {
|
|
return fmt.Errorf("failed to finalize multipart upload: %s, Failed to cancel multipart upload: %s", err, errC)
|
|
}
|
|
return fmt.Errorf("failed to finalize multipart upload: %s", err)
|
|
}
|
|
|
|
// Query parameter names placed on the signed URLs of the XML multipart
// upload requests.
const (
	// mpuInitiateQuery is sent with an empty value when initiating an upload.
	mpuInitiateQuery = "uploads"
	// mpuPartNumberQuery carries the number of the part being uploaded.
	mpuPartNumberQuery = "partNumber"
	// mpuUploadIDQuery carries the upload ID on part, complete and cancel
	// requests.
	mpuUploadIDQuery = "uploadId"
)
|
|
|
|
type uploadBase struct {
|
|
cli *storage.Client
|
|
ctx context.Context
|
|
bucket string
|
|
blob string
|
|
retry int
|
|
signedURLExpiry time.Duration
|
|
}
|
|
|
|
// Defaults applied by NewGCSWriter and the limits it and Write enforce.
const (
	// defaultRetry is the number of attempts made for each part.
	defaultRetry = 3
	// defaultSignedURLExpiry is the validity period of each signed URL.
	defaultSignedURLExpiry = 6 * time.Hour

	// gcsMinimumChunkSize is the smallest accepted part size (5 MiB).
	gcsMinimumChunkSize = 5 << 20
	// gcsMaximumChunkSize is the largest accepted part size (5 GiB).
	gcsMaximumChunkSize = 5 << 30
	// gcsMaximumParts is the highest part number Write will hand out.
	gcsMaximumParts = 10000
)
|
|
|
|
// InitiateMultipartUploadResult mirrors the XML document returned when a
// multipart upload is initiated. The writer only consumes UploadID.
type InitiateMultipartUploadResult struct {
	// XMLName pins the expected root element name.
	XMLName xml.Name `xml:"InitiateMultipartUploadResult"`
	// Text receives character data found directly inside the root element.
	Text string `xml:",chardata"`
	// Xmlns receives the root element's xmlns attribute.
	Xmlns string `xml:"xmlns,attr"`
	// Bucket is the content of the <Bucket> element.
	Bucket string `xml:"Bucket"`
	// Key is the content of the <Key> element.
	Key string `xml:"Key"`
	// UploadID is the content of the <UploadId> element.
	UploadID string `xml:"UploadId"`
}
|
|
|
|
// Part is one <Part> entry of the XML body sent to complete a multipart
// upload.
type Part struct {
	// Text holds character data found directly inside the element.
	Text string `xml:",chardata"`
	// PartNumber is the number under which the part was uploaded.
	PartNumber int `xml:"PartNumber"`
	// ETag is the ETag header value returned when the part was uploaded.
	ETag string `xml:"ETag"`
}
|
|
|
|
// CompleteMultipartUpload is the complete multipart upload structure.
|
|
type CompleteMultipartUpload struct {
|
|
XMLName xml.Name `xml:"CompleteMultipartUpload"`
|
|
Text string `xml:",chardata"`
|
|
Parts []Part `xml:"Part"`
|
|
}
|
|
|
|
func (w *GCSWriter) finalizeXMLMPU() error {
|
|
finalXMLRoot := CompleteMultipartUpload{
|
|
Parts: make([]Part, 0, len(w.xmlMPUParts)),
|
|
}
|
|
slices.SortFunc(w.xmlMPUParts, func(a, b *xmlMPUPart) int {
|
|
return a.partNumber - b.partNumber
|
|
})
|
|
for _, part := range w.xmlMPUParts {
|
|
part := Part{
|
|
PartNumber: part.partNumber,
|
|
ETag: part.etag,
|
|
}
|
|
finalXMLRoot.Parts = append(finalXMLRoot.Parts, part)
|
|
}
|
|
|
|
xmlBytes, err := xml.Marshal(finalXMLRoot)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encode XML: %v", err)
|
|
}
|
|
|
|
opts := &storage.SignedURLOptions{
|
|
Scheme: storage.SigningSchemeV4,
|
|
Method: "POST",
|
|
Expires: time.Now().Add(w.signedURLExpiry),
|
|
QueryParameters: url.Values{mpuUploadIDQuery: []string{w.uploadID}},
|
|
}
|
|
u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
|
|
if err != nil {
|
|
return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
|
|
}
|
|
|
|
client := resty.New()
|
|
resp, err := client.R().SetBody(xmlBytes).Post(u)
|
|
if err != nil {
|
|
return fmt.Errorf("POST request failed: %s", err)
|
|
}
|
|
|
|
if resp.StatusCode() != http.StatusOK {
|
|
return fmt.Errorf("POST request returned non-OK status: %d, body: %s", resp.StatusCode(), resp.String())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// chunk is the unit of work passed from Write to an upload worker.
type chunk struct {
	// buf is the payload of the part.
	buf []byte
	// num is the part number assigned to the payload.
	num int
	// cleanup is called by the worker once it is done with the chunk.
	cleanup func()
}
|
|
|
|
func (w *GCSWriter) appendMPUPart(part *xmlMPUPart) {
|
|
w.mutex.Lock()
|
|
defer w.mutex.Unlock()
|
|
|
|
w.xmlMPUParts = append(w.xmlMPUParts, part)
|
|
}
|
|
|
|
func (w *GCSWriter) cancel() error {
|
|
opts := &storage.SignedURLOptions{
|
|
Scheme: storage.SigningSchemeV4,
|
|
Method: "DELETE",
|
|
Expires: time.Now().Add(w.signedURLExpiry),
|
|
QueryParameters: url.Values{mpuUploadIDQuery: []string{w.uploadID}},
|
|
}
|
|
u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
|
|
if err != nil {
|
|
return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
|
|
}
|
|
|
|
client := resty.New()
|
|
resp, err := client.R().Delete(u)
|
|
if err != nil {
|
|
return fmt.Errorf("DELETE request failed: %s", err)
|
|
}
|
|
|
|
if resp.StatusCode() != http.StatusNoContent {
|
|
return fmt.Errorf("DELETE request returned non-204 status: %d", resp.StatusCode())
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
type xmlMPUPart struct {
|
|
uploadBase
|
|
buf []byte
|
|
uploadID string
|
|
partNumber int
|
|
etag string
|
|
}
|
|
|
|
func (p *xmlMPUPart) Clone() *xmlMPUPart {
|
|
return &xmlMPUPart{
|
|
uploadBase: p.uploadBase,
|
|
uploadID: p.uploadID,
|
|
buf: p.buf,
|
|
partNumber: p.partNumber,
|
|
}
|
|
}
|
|
|
|
func (p *xmlMPUPart) Upload() error {
|
|
var err error
|
|
for range p.retry {
|
|
err = p.upload()
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("failed to upload part %d: %w", p.partNumber, err)
|
|
}
|
|
|
|
func (p *xmlMPUPart) upload() error {
|
|
opts := &storage.SignedURLOptions{
|
|
Scheme: storage.SigningSchemeV4,
|
|
Method: "PUT",
|
|
Expires: time.Now().Add(p.signedURLExpiry),
|
|
QueryParameters: url.Values{
|
|
mpuUploadIDQuery: []string{p.uploadID},
|
|
mpuPartNumberQuery: []string{strconv.Itoa(p.partNumber)},
|
|
},
|
|
}
|
|
|
|
u, err := p.cli.Bucket(p.bucket).SignedURL(p.blob, opts)
|
|
if err != nil {
|
|
return fmt.Errorf("Bucket(%q).SignedURL: %s", p.bucket, err)
|
|
}
|
|
|
|
req, err := http.NewRequest("PUT", u, bytes.NewReader(p.buf))
|
|
if err != nil {
|
|
return fmt.Errorf("PUT request failed: %s", err)
|
|
}
|
|
req = req.WithContext(p.ctx)
|
|
|
|
client := &http.Client{
|
|
Transport: createTransport(nil),
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("PUT request failed: %s", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
p.etag = resp.Header.Get("ETag")
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("PUT request returned non-OK status: %d", resp.StatusCode)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// createTransport builds a new http.Transport that honours the proxy settings
// of the environment and dials with a 30s connect timeout and a 30s
// keep-alive period. A non-nil localAddr is used as the local address of
// outgoing connections.
//
// The transport keeps at most 100 idle connections overall and
// GOMAXPROCS+1 per host, drops idle connections after 90s, and limits the
// TLS handshake to 10s and the wait for a "100 Continue" response to 1s.
func createTransport(localAddr net.Addr) *http.Transport {
	const (
		dialTimeout     = 30 * time.Second
		keepAlivePeriod = 30 * time.Second
	)

	// Assigning a nil localAddr leaves the field at its zero value, exactly
	// as if it had not been set.
	d := &net.Dialer{
		Timeout:   dialTimeout,
		KeepAlive: keepAlivePeriod,
		LocalAddr: localAddr,
	}

	tr := &http.Transport{}
	tr.Proxy = http.ProxyFromEnvironment
	tr.DialContext = d.DialContext
	tr.MaxIdleConns = 100
	tr.MaxIdleConnsPerHost = runtime.GOMAXPROCS(0) + 1
	tr.IdleConnTimeout = 90 * time.Second
	tr.TLSHandshakeTimeout = 10 * time.Second
	tr.ExpectContinueTimeout = 1 * time.Second
	return tr
}
|