diff --git a/cmd/server.go b/cmd/server.go index d9206cfe..4263f020 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -4,9 +4,6 @@ import ( "context" "errors" "fmt" - ftpserver "github.com/KirCute/ftpserverlib-pasvportmap" - "github.com/KirCute/sftpd-alist" - "github.com/alist-org/alist/v3/internal/fs" "net" "net/http" "os" @@ -16,14 +13,19 @@ import ( "syscall" "time" + ftpserver "github.com/KirCute/ftpserverlib-pasvportmap" + "github.com/KirCute/sftpd-alist" "github.com/alist-org/alist/v3/cmd/flags" "github.com/alist-org/alist/v3/internal/bootstrap" "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) // ServerCmd represents the server command @@ -47,11 +49,15 @@ the address is defined in config file`, r := gin.New() r.Use(gin.LoggerWithWriter(log.StandardLogger().Out), gin.RecoveryWithWriter(log.StandardLogger().Out)) server.Init(r) + var httpHandler http.Handler = r + if conf.Conf.Scheme.EnableH2c { + httpHandler = h2c.NewHandler(r, &http2.Server{}) + } var httpSrv, httpsSrv, unixSrv *http.Server if conf.Conf.Scheme.HttpPort != -1 { httpBase := fmt.Sprintf("%s:%d", conf.Conf.Scheme.Address, conf.Conf.Scheme.HttpPort) utils.Log.Infof("start HTTP server @ %s", httpBase) - httpSrv = &http.Server{Addr: httpBase, Handler: r} + httpSrv = &http.Server{Addr: httpBase, Handler: httpHandler} go func() { err := httpSrv.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { @@ -72,7 +78,7 @@ the address is defined in config file`, } if conf.Conf.Scheme.UnixFile != "" { utils.Log.Infof("start unix server @ %s", conf.Conf.Scheme.UnixFile) - unixSrv = &http.Server{Handler: r} + unixSrv = &http.Server{Handler: httpHandler} go func() { listener, err := net.Listen("unix", conf.Conf.Scheme.UnixFile) if err != nil { diff --git a/drivers/115/util.go b/drivers/115/util.go index 7298f565..fc17fe3c 100644 --- a/drivers/115/util.go +++ b/drivers/115/util.go @@ -405,7 +405,7 @@ func (d *Pan115) UploadByMultipart(ctx context.Context, params *driver115.Upload if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) { continue } - if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf)), + if part, err = bucket.UploadPart(imur, driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf)), chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil { break } diff --git a/drivers/123/driver.go b/drivers/123/driver.go index 7d457138..32c053e2 100644 --- a/drivers/123/driver.go +++ b/drivers/123/driver.go @@ -2,11 +2,8 @@ package _123 import ( "context" - "crypto/md5" "encoding/base64" - "encoding/hex" "fmt" - "io" "net/http" "net/url" "sync" @@ -18,6 +15,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -187,25 +185,12 @@ func (d *Pan123) Remove(ctx context.Context, obj model.Obj) error { func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { etag := file.GetHash().GetHash(utils.MD5) + var err error if len(etag) < utils.MD5.Width { - // const DEFAULT int64 = 10485760 - h := md5.New() - // need to calculate md5 of the full content - tempFile, err := file.CacheFullInTempFile() + _, etag, err = stream.CacheFullInTempFileAndHash(file, utils.MD5) if err != nil { return err } - defer func() { - _ = tempFile.Close() - }() - if _, err = utils.CopyWithBuffer(h, tempFile); err != nil { - return err - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err - } - etag = hex.EncodeToString(h.Sum(nil)) } data := base.Json{ "driveId": 0, diff --git a/drivers/123/upload.go b/drivers/123/upload.go index dc148c4c..b0482a9f 100644 --- a/drivers/123/upload.go +++ b/drivers/123/upload.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "math" "net/http" "strconv" @@ -70,27 +69,33 @@ func (d *Pan123) completeS3(ctx context.Context, upReq *UploadResp, file model.F } func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, up driver.UpdateProgress) error { - chunkSize := int64(1024 * 1024 * 16) + tmpF, err := file.CacheFullInTempFile() + if err != nil { + return err + } // fetch s3 pre signed urls - chunkCount := int(math.Ceil(float64(file.GetSize()) / float64(chunkSize))) + size := file.GetSize() + chunkSize := min(size, 16*utils.MB) + chunkCount := int(size / chunkSize) + lastChunkSize := size % chunkSize + if lastChunkSize > 0 { + chunkCount++ + } else { + lastChunkSize = chunkSize + } // only 1 batch is allowed - isMultipart := chunkCount > 1 batchSize := 1 getS3UploadUrl := d.getS3Auth - if isMultipart { + if chunkCount > 1 { batchSize = 10 getS3UploadUrl = d.getS3PreSignedUrls } - limited := driver.NewLimitedUploadStream(ctx, file) for i := 1; i <= chunkCount; i += batchSize { if utils.IsCanceled(ctx) { return ctx.Err() } start := i - end := i + batchSize - if end > chunkCount+1 { - end = chunkCount + 1 - } + end := min(i+batchSize, chunkCount+1) s3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, start, end) if err != nil { return err @@ -102,9 +107,9 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi } curSize := chunkSize if j == chunkCount { - curSize = file.GetSize() - (int64(chunkCount)-1)*chunkSize + curSize = lastChunkSize } - err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(limited, chunkSize), curSize, false, getS3UploadUrl) + err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.NewSectionReader(tmpF, chunkSize*int64(j-1), curSize), curSize, false, getS3UploadUrl) if err != nil { return err } @@ -115,12 +120,12 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi return d.completeS3(ctx, upReq, file, chunkCount > 1) } -func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader io.Reader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error { +func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader *io.SectionReader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error { uploadUrl := s3PreSignedUrls.Data.PreSignedUrls[strconv.Itoa(cur)] if uploadUrl == "" { return fmt.Errorf("upload url is empty, s3PreSignedUrls: %+v", s3PreSignedUrls) } - req, err := http.NewRequest("PUT", uploadUrl, reader) + req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, reader)) if err != nil { return err } @@ -143,6 +148,7 @@ func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSign } s3PreSignedUrls.Data.PreSignedUrls = newS3PreSignedUrls.Data.PreSignedUrls // retry + reader.Seek(0, io.SeekStart) return d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, cur, end, reader, curSize, true, getS3UploadUrl) } if res.StatusCode != http.StatusOK { diff --git a/drivers/139/driver.go b/drivers/139/driver.go index e45f082d..a9c59f72 100644 --- a/drivers/139/driver.go +++ b/drivers/139/driver.go @@ -16,6 +16,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + streamPkg "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/cron" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils/random" @@ -120,7 +121,7 @@ func (d *Yun139) Init(ctx context.Context) error { } } - return err + return nil } func (d *Yun139) InitReference(storage driver.Driver) error { @@ -530,23 +531,15 @@ func (d *Yun139) Remove(ctx context.Context, obj model.Obj) error { } } -const ( - _ = iota //ignore first value by assigning to blank identifier - KB = 1 << (10 * iota) - MB - GB - TB -) - func (d *Yun139) getPartSize(size int64) int64 { if d.CustomUploadPartSize != 0 { return d.CustomUploadPartSize } // 网盘对于分片数量存在上限 - if size/GB > 30 { - return 512 * MB + if size/utils.GB > 30 { + return 512 * utils.MB } - return 100 * MB + return 100 * utils.MB } func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { @@ -554,29 +547,28 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr case MetaPersonalNew: var err error fullHash := stream.GetHash().GetHash(utils.SHA256) - if len(fullHash) <= 0 { - tmpF, err := stream.CacheFullInTempFile() - if err != nil { - return err - } - fullHash, err = utils.HashFile(utils.SHA256, tmpF) + if len(fullHash) != utils.SHA256.Width { + _, fullHash, err = streamPkg.CacheFullInTempFileAndHash(stream, utils.SHA256) if err != nil { return err } } - partInfos := []PartInfo{} - var partSize = d.getPartSize(stream.GetSize()) - part := (stream.GetSize() + partSize - 1) / partSize - if part == 0 { + size := stream.GetSize() + var partSize = d.getPartSize(size) + part := size / partSize + if size%partSize > 0 { + part++ + } else if part == 0 { part = 1 } + partInfos := make([]PartInfo, 0, part) for i := int64(0); i < part; i++ { if utils.IsCanceled(ctx) { return ctx.Err() } start := i * partSize - byteSize := stream.GetSize() - start + byteSize := size - start if byteSize > partSize { byteSize = partSize } @@ -604,7 +596,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr "contentType": "application/octet-stream", "parallelUpload": false, "partInfos": firstPartInfos, - "size": stream.GetSize(), + "size": size, "parentFileId": dstDir.GetID(), "name": stream.GetName(), "type": "file", @@ -657,7 +649,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr } // Progress - p := driver.NewProgress(stream.GetSize(), up) + p := driver.NewProgress(size, up) rateLimited := driver.NewLimitedUploadStream(ctx, stream) // 上传所有分片 @@ -817,12 +809,14 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr return fmt.Errorf("get file upload url failed with result code: %s, message: %s", resp.Data.Result.ResultCode, resp.Data.Result.ResultDesc) } + size := stream.GetSize() // Progress - p := driver.NewProgress(stream.GetSize(), up) - - var partSize = d.getPartSize(stream.GetSize()) - part := (stream.GetSize() + partSize - 1) / partSize - if part == 0 { + p := driver.NewProgress(size, up) + var partSize = d.getPartSize(size) + part := size / partSize + if size%partSize > 0 { + part++ + } else if part == 0 { part = 1 } rateLimited := driver.NewLimitedUploadStream(ctx, stream) @@ -832,10 +826,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr } start := i * partSize - byteSize := stream.GetSize() - start - if byteSize > partSize { - byteSize = partSize - } + byteSize := min(size-start, partSize) limitReader := io.LimitReader(rateLimited, byteSize) // Update Progress @@ -847,7 +838,7 @@ func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr req = req.WithContext(ctx) req.Header.Set("Content-Type", "text/plain;name="+unicode(stream.GetName())) - req.Header.Set("contentSize", strconv.FormatInt(stream.GetSize(), 10)) + req.Header.Set("contentSize", strconv.FormatInt(size, 10)) req.Header.Set("range", fmt.Sprintf("bytes=%d-%d", start, start+byteSize-1)) req.Header.Set("uploadtaskID", resp.Data.UploadResult.UploadTaskID) req.Header.Set("rangeType", "0") diff --git a/drivers/139/util.go b/drivers/139/util.go index a5371fda..91458436 100644 --- a/drivers/139/util.go +++ b/drivers/139/util.go @@ -67,6 +67,7 @@ func (d *Yun139) refreshToken() error { if len(splits) < 3 { return fmt.Errorf("authorization is invalid, splits < 3") } + d.Account = splits[1] strs := strings.Split(splits[2], "|") if len(strs) < 4 { return fmt.Errorf("authorization is invalid, strs < 4") diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index fb1a183a..c391f7e6 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -3,16 +3,15 @@ package _189pc import ( "bytes" "context" - "crypto/md5" "encoding/base64" "encoding/hex" "encoding/xml" "fmt" "io" - "math" "net/http" "net/http/cookiejar" "net/url" + "os" "regexp" "sort" "strconv" @@ -28,6 +27,7 @@ import ( "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/errgroup" "github.com/alist-org/alist/v3/pkg/utils" @@ -473,12 +473,8 @@ func (y *Cloud189PC) refreshSession() (err error) { // 普通上传 // 无法上传大小为0的文件 func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { - var sliceSize = partSize(file.GetSize()) - count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize))) - lastPartSize := file.GetSize() % sliceSize - if file.GetSize() > 0 && lastPartSize == 0 { - lastPartSize = sliceSize - } + size := file.GetSize() + sliceSize := partSize(size) params := Params{ "parentFolderId": dstDir.GetID(), @@ -512,22 +508,29 @@ func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file mo retry.DelayType(retry.BackOffDelay)) sem := semaphore.NewWeighted(3) - fileMd5 := md5.New() - silceMd5 := md5.New() + count := int(size / sliceSize) + lastPartSize := size % sliceSize + if lastPartSize > 0 { + count++ + } else { + lastPartSize = sliceSize + } + fileMd5 := utils.MD5.NewFunc() + silceMd5 := utils.MD5.NewFunc() silceMd5Hexs := make([]string, 0, count) - + teeReader := io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5)) + byteSize := sliceSize for i := 1; i <= count; i++ { if utils.IsCanceled(upCtx) { break } - byteData := make([]byte, sliceSize) if i == count { - byteData = byteData[:lastPartSize] + byteSize = lastPartSize } - + byteData := make([]byte, byteSize) // 读取块 silceMd5.Reset() - if _, err := io.ReadFull(io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5)), byteData); err != io.EOF && err != nil { + if _, err := io.ReadFull(teeReader, byteData); err != io.EOF && err != nil { sem.Release(1) return nil, err } @@ -607,24 +610,43 @@ func (y *Cloud189PC) RapidUpload(ctx context.Context, dstDir model.Obj, stream m // 快传 func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { - tempFile, err := file.CacheFullInTempFile() - if err != nil { - return nil, err + var ( + cache = file.GetFile() + tmpF *os.File + err error + ) + size := file.GetSize() + if _, ok := cache.(io.ReaderAt); !ok && size > 0 { + tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*") + if err != nil { + return nil, err + } + defer func() { + _ = tmpF.Close() + _ = os.Remove(tmpF.Name()) + }() + cache = tmpF } - - var sliceSize = partSize(file.GetSize()) - count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize))) - lastSliceSize := file.GetSize() % sliceSize - if file.GetSize() > 0 && lastSliceSize == 0 { + sliceSize := partSize(size) + count := int(size / sliceSize) + lastSliceSize := size % sliceSize + if lastSliceSize > 0 { + count++ + } else { lastSliceSize = sliceSize } //step.1 优先计算所需信息 byteSize := sliceSize - fileMd5 := md5.New() - silceMd5 := md5.New() - silceMd5Hexs := make([]string, 0, count) + fileMd5 := utils.MD5.NewFunc() + sliceMd5 := utils.MD5.NewFunc() + sliceMd5Hexs := make([]string, 0, count) partInfos := make([]string, 0, count) + writers := []io.Writer{fileMd5, sliceMd5} + if tmpF != nil { + writers = append(writers, tmpF) + } + written := int64(0) for i := 1; i <= count; i++ { if utils.IsCanceled(ctx) { return nil, ctx.Err() @@ -634,19 +656,31 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode byteSize = lastSliceSize } - silceMd5.Reset() - if _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5, silceMd5), tempFile, byteSize); err != nil && err != io.EOF { + n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), file, byteSize) + written += n + if err != nil && err != io.EOF { return nil, err } - md5Byte := silceMd5.Sum(nil) - silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte))) + md5Byte := sliceMd5.Sum(nil) + sliceMd5Hexs = append(sliceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte))) partInfos = append(partInfos, fmt.Sprint(i, "-", base64.StdEncoding.EncodeToString(md5Byte))) + sliceMd5.Reset() + } + + if tmpF != nil { + if size > 0 && written != size { + return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, size) + } + _, err = tmpF.Seek(0, io.SeekStart) + if err != nil { + return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ") + } } fileMd5Hex := strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil))) sliceMd5Hex := fileMd5Hex - if file.GetSize() > sliceSize { - sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(silceMd5Hexs, "\n"))) + if size > sliceSize { + sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(sliceMd5Hexs, "\n"))) } fullUrl := UPLOAD_URL @@ -712,7 +746,7 @@ func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file mode } // step.4 上传切片 - _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(tempFile, offset, byteSize), isFamily) + _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(cache, offset, byteSize), isFamily) if err != nil { return err } @@ -794,11 +828,7 @@ func (y *Cloud189PC) GetMultiUploadUrls(ctx context.Context, isFamily bool, uplo // 旧版本上传,家庭云不支持覆盖 func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { - tempFile, err := file.CacheFullInTempFile() - if err != nil { - return nil, err - } - fileMd5, err := utils.HashFile(utils.MD5, tempFile) + tempFile, fileMd5, err := stream.CacheFullInTempFileAndHash(file, utils.MD5) if err != nil { return nil, err } diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go index fb730de6..4114c195 100644 --- a/drivers/aliyundrive_open/upload.go +++ b/drivers/aliyundrive_open/upload.go @@ -1,7 +1,6 @@ package aliyundrive_open import ( - "bytes" "context" "encoding/base64" "fmt" @@ -15,6 +14,7 @@ import ( "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" + streamPkg "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/http_range" "github.com/alist-org/alist/v3/pkg/utils" "github.com/avast/retry-go" @@ -131,16 +131,19 @@ func (d *AliyundriveOpen) calProofCode(stream model.FileStreamer) (string, error return "", err } length := proofRange.End - proofRange.Start - buf := bytes.NewBuffer(make([]byte, 0, length)) reader, err := stream.RangeRead(http_range.Range{Start: proofRange.Start, Length: length}) if err != nil { return "", err } - _, err = utils.CopyWithBufferN(buf, reader, length) + buf := make([]byte, length) + n, err := io.ReadFull(reader, buf) + if err == io.ErrUnexpectedEOF { + return "", fmt.Errorf("can't read data, expected=%d, got=%d", len(buf), n) + } if err != nil { return "", err } - return base64.StdEncoding.EncodeToString(buf.Bytes()), nil + return base64.StdEncoding.EncodeToString(buf), nil } func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { @@ -183,25 +186,18 @@ func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream m _, err, e := d.requestReturnErrResp("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) { req.SetBody(createData).SetResult(&createResp) }) - var tmpF model.File if err != nil { if e.Code != "PreHashMatched" || !rapidUpload { return nil, err } log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload") - hi := stream.GetHash() - hash := hi.GetHash(utils.SHA1) - if len(hash) <= 0 { - tmpF, err = stream.CacheFullInTempFile() + hash := stream.GetHash().GetHash(utils.SHA1) + if len(hash) != utils.SHA1.Width { + _, hash, err = streamPkg.CacheFullInTempFileAndHash(stream, utils.SHA1) if err != nil { return nil, err } - hash, err = utils.HashFile(utils.SHA1, tmpF) - if err != nil { - return nil, err - } - } delete(createData, "pre_hash") diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index 3cc1ae9e..c33e0b32 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -6,8 +6,8 @@ import ( "encoding/hex" "errors" "io" - "math" "net/url" + "os" stdpath "path" "strconv" "time" @@ -15,6 +15,7 @@ import ( "golang.org/x/sync/semaphore" "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" @@ -185,16 +186,30 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F return newObj, nil } - tempFile, err := stream.CacheFullInTempFile() - if err != nil { - return nil, err + var ( + cache = stream.GetFile() + tmpF *os.File + err error + ) + if _, ok := cache.(io.ReaderAt); !ok { + tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*") + if err != nil { + return nil, err + } + defer func() { + _ = tmpF.Close() + _ = os.Remove(tmpF.Name()) + }() + cache = tmpF } streamSize := stream.GetSize() sliceSize := d.getSliceSize(streamSize) - count := int(math.Max(math.Ceil(float64(streamSize)/float64(sliceSize)), 1)) + count := int(streamSize / sliceSize) lastBlockSize := streamSize % sliceSize - if streamSize > 0 && lastBlockSize == 0 { + if lastBlockSize > 0 { + count++ + } else { lastBlockSize = sliceSize } @@ -207,6 +222,11 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F sliceMd5H := md5.New() sliceMd5H2 := md5.New() slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) + writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write} + if tmpF != nil { + writers = append(writers, tmpF) + } + written := int64(0) for i := 1; i <= count; i++ { if utils.IsCanceled(ctx) { @@ -215,13 +235,23 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F if i == count { byteSize = lastBlockSize } - _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) + n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize) + written += n if err != nil && err != io.EOF { return nil, err } blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil))) sliceMd5H.Reset() } + if tmpF != nil { + if written != streamSize { + return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize) + } + _, err = tmpF.Seek(0, io.SeekStart) + if err != nil { + return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ") + } + } contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) blockListStr, _ := utils.Json.MarshalToString(blockList) @@ -291,7 +321,7 @@ func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.F "partseq": strconv.Itoa(partseq), } err := d.uploadSlice(ctx, params, stream.GetName(), - driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize))) + driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize))) if err != nil { return err } diff --git a/drivers/baidu_photo/driver.go b/drivers/baidu_photo/driver.go index eeee746f..5a34fcb4 100644 --- a/drivers/baidu_photo/driver.go +++ b/drivers/baidu_photo/driver.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" "io" - "math" + "os" "regexp" "strconv" "strings" @@ -16,6 +16,7 @@ import ( "golang.org/x/sync/semaphore" "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" @@ -241,11 +242,21 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil // TODO: // 暂时没有找到妙传方式 - - // 需要获取完整文件md5,必须支持 io.Seek - tempFile, err := stream.CacheFullInTempFile() - if err != nil { - return nil, err + var ( + cache = stream.GetFile() + tmpF *os.File + err error + ) + if _, ok := cache.(io.ReaderAt); !ok { + tmpF, err = os.CreateTemp(conf.Conf.TempDir, "file-*") + if err != nil { + return nil, err + } + defer func() { + _ = tmpF.Close() + _ = os.Remove(tmpF.Name()) + }() + cache = tmpF } const DEFAULT int64 = 1 << 22 @@ -253,9 +264,11 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil // 计算需要的数据 streamSize := stream.GetSize() - count := int(math.Ceil(float64(streamSize) / float64(DEFAULT))) + count := int(streamSize / DEFAULT) lastBlockSize := streamSize % DEFAULT - if lastBlockSize == 0 { + if lastBlockSize > 0 { + count++ + } else { lastBlockSize = DEFAULT } @@ -266,6 +279,11 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil sliceMd5H := md5.New() sliceMd5H2 := md5.New() slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) + writers := []io.Writer{fileMd5H, sliceMd5H, slicemd5H2Write} + if tmpF != nil { + writers = append(writers, tmpF) + } + written := int64(0) for i := 1; i <= count; i++ { if utils.IsCanceled(ctx) { return nil, ctx.Err() @@ -273,13 +291,23 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil if i == count { byteSize = lastBlockSize } - _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) + n, err := utils.CopyWithBufferN(io.MultiWriter(writers...), stream, byteSize) + written += n if err != nil && err != io.EOF { return nil, err } sliceMD5List = append(sliceMD5List, hex.EncodeToString(sliceMd5H.Sum(nil))) sliceMd5H.Reset() } + if tmpF != nil { + if written != streamSize { + return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", written, streamSize) + } + _, err = tmpF.Seek(0, io.SeekStart) + if err != nil { + return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ") + } + } contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) blockListStr, _ := utils.Json.MarshalToString(sliceMD5List) @@ -291,7 +319,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil "rtype": "1", "ctype": "11", "path": fmt.Sprintf("/%s", stream.GetName()), - "size": fmt.Sprint(stream.GetSize()), + "size": fmt.Sprint(streamSize), "slice-md5": sliceMd5, "content-md5": contentMd5, "block_list": blockListStr, @@ -343,7 +371,7 @@ func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.Fil r.SetContext(ctx) r.SetQueryParams(uploadParams) r.SetFileReader("file", stream.GetName(), - driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, byteSize))) + driver.NewLimitedUploadStream(ctx, io.NewSectionReader(cache, offset, byteSize))) }, nil) if err != nil { return err diff --git a/drivers/cloudreve/util.go b/drivers/cloudreve/util.go index 1fd5ed8a..196d7303 100644 --- a/drivers/cloudreve/util.go +++ b/drivers/cloudreve/util.go @@ -204,7 +204,7 @@ func (d *Cloudreve) upLocal(ctx context.Context, stream model.FileStreamer, u Up req.SetContentLength(true) req.SetHeader("Content-Length", strconv.FormatInt(byteSize, 10)) req.SetHeader("User-Agent", d.getUA()) - req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) + req.SetBody(driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) }, nil) if err != nil { break @@ -239,7 +239,7 @@ func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u U return err } req, err := http.NewRequest("POST", uploadUrl+"?chunk="+strconv.Itoa(chunk), - driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) + driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) if err != nil { return err } @@ -280,7 +280,7 @@ func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u if err != nil { return err } - req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) + req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) if err != nil { return err } diff --git a/drivers/doubao/driver.go b/drivers/doubao/driver.go index 04f74325..a066feee 100644 --- a/drivers/doubao/driver.go +++ b/drivers/doubao/driver.go @@ -3,19 +3,25 @@ package doubao import ( "context" "errors" - "time" - "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" "github.com/google/uuid" + "net/http" + "strconv" + "strings" + "time" ) type Doubao struct { model.Storage Addition + *UploadToken + UserId string + uploadThread int } func (d *Doubao) Config() driver.Config { @@ -29,6 +35,31 @@ func (d *Doubao) GetAddition() driver.Additional { func (d *Doubao) Init(ctx context.Context) error { // TODO login / refresh token //op.MustSaveDriverStorage(d) + uploadThread, err := strconv.Atoi(d.UploadThread) + if err != nil || uploadThread < 1 { + d.uploadThread, d.UploadThread = 3, "3" // Set default value + } else { + d.uploadThread = uploadThread + } + + if d.UserId == "" { + userInfo, err := d.getUserInfo() + if err != nil { + return err + } + + d.UserId = strconv.FormatInt(userInfo.UserID, 10) + } + + if d.UploadToken == nil { + uploadToken, err := d.initUploadToken() + if err != nil { + return err + } + + d.UploadToken = uploadToken + } + return nil } @@ -38,18 +69,12 @@ func (d *Doubao) Drop(ctx context.Context) error { func (d *Doubao) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { var files []model.Obj - var r NodeInfoResp - _, err := d.request("/samantha/aispace/node_info", "POST", func(req *resty.Request) { - req.SetBody(base.Json{ - "node_id": dir.GetID(), - "need_full_path": false, - }) - }, &r) + fileList, err := d.getFiles(dir.GetID(), "") if err != nil { return nil, err } - for _, child := range r.Data.Children { + for _, child := range fileList { files = append(files, &Object{ Object: model.Object{ ID: child.ID, @@ -60,34 +85,65 @@ func (d *Doubao) List(ctx context.Context, dir model.Obj, args model.ListArgs) ( Ctime: time.Unix(child.CreateTime, 0), IsFolder: child.NodeType == 1, }, - Key: child.Key, + Key: child.Key, + NodeType: child.NodeType, }) } + return files, nil } func (d *Doubao) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var downloadUrl string + if u, ok := file.(*Object); ok { - var r GetFileUrlResp - _, err := d.request("/alice/message/get_file_url", "POST", func(req *resty.Request) { - req.SetBody(base.Json{ - "uris": []string{u.Key}, - "type": "file", - }) - }, &r) - if err != nil { - return nil, err + switch u.NodeType { + case VideoType, AudioType: + var r GetVideoFileUrlResp + _, err := d.request("/samantha/media/get_play_info", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "key": u.Key, + "node_id": file.GetID(), + }) + }, &r) + if err != nil { + return nil, err + } + + downloadUrl = r.Data.OriginalMediaInfo.MainURL + default: + var r GetFileUrlResp + _, err := d.request("/alice/message/get_file_url", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "uris": []string{u.Key}, + "type": FileNodeType[u.NodeType], + }) + }, &r) + if err != nil { + return nil, err + } + + downloadUrl = r.Data.FileUrls[0].MainURL } + + // 生成标准的Content-Disposition + contentDisposition := generateContentDisposition(u.Name) + return &model.Link{ - URL: r.Data.FileUrls[0].MainURL, + URL: downloadUrl, + Header: http.Header{ + "User-Agent": []string{UserAgent}, + "Content-Disposition": []string{contentDisposition}, + }, }, nil } + return nil, errors.New("can't convert obj to URL") } func (d *Doubao) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { var r UploadNodeResp - _, err := d.request("/samantha/aispace/upload_node", "POST", func(req *resty.Request) { + _, err := d.request("/samantha/aispace/upload_node", http.MethodPost, func(req *resty.Request) { req.SetBody(base.Json{ "node_list": []base.Json{ { @@ -104,7 +160,7 @@ func (d *Doubao) MakeDir(ctx context.Context, parentDir model.Obj, dirName strin func (d *Doubao) Move(ctx context.Context, srcObj, dstDir model.Obj) error { var r UploadNodeResp - _, err := d.request("/samantha/aispace/move_node", "POST", func(req *resty.Request) { + _, err := d.request("/samantha/aispace/move_node", http.MethodPost, func(req *resty.Request) { req.SetBody(base.Json{ "node_list": []base.Json{ {"id": srcObj.GetID()}, @@ -118,7 +174,7 @@ func (d *Doubao) Move(ctx context.Context, srcObj, dstDir model.Obj) error { func (d *Doubao) Rename(ctx context.Context, srcObj model.Obj, newName string) error { var r BaseResp - _, err := d.request("/samantha/aispace/rename_node", "POST", func(req *resty.Request) { + _, err := d.request("/samantha/aispace/rename_node", http.MethodPost, func(req *resty.Request) { req.SetBody(base.Json{ "node_id": srcObj.GetID(), "node_name": newName, @@ -134,15 +190,38 @@ func (d *Doubao) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, func (d *Doubao) Remove(ctx context.Context, obj model.Obj) error { var r BaseResp - _, err := d.request("/samantha/aispace/delete_node", "POST", func(req *resty.Request) { + _, err := d.request("/samantha/aispace/delete_node", http.MethodPost, func(req *resty.Request) { req.SetBody(base.Json{"node_list": []base.Json{{"id": obj.GetID()}}}) }, &r) return err } func (d *Doubao) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { - // TODO upload file, optional - return nil, errs.NotImplement + // 根据MIME类型确定数据类型 + mimetype := file.GetMimetype() + dataType := FileDataType + + switch { + case strings.HasPrefix(mimetype, "video/"): + dataType = VideoDataType + case strings.HasPrefix(mimetype, "audio/"): + dataType = VideoDataType // 音频与视频使用相同的处理方式 + case strings.HasPrefix(mimetype, "image/"): + dataType = ImgDataType + } + + // 获取上传配置 + uploadConfig := UploadConfig{} + if err := d.getUploadConfig(&uploadConfig, dataType, file); err != nil { + return nil, err + } + + // 根据文件大小选择上传方式 + if file.GetSize() <= 1*utils.MB { // 小于1MB,使用普通模式上传 + return d.Upload(&uploadConfig, dstDir, file, up, dataType) + } + // 大文件使用分片上传 + return d.UploadByMultipart(ctx, &uploadConfig, file.GetSize(), dstDir, file, up, dataType) } func (d *Doubao) GetArchiveMeta(ctx context.Context, obj model.Obj, args model.ArchiveArgs) (model.ArchiveMeta, error) { diff --git a/drivers/doubao/meta.go b/drivers/doubao/meta.go index bb9e3f25..c3d8eb34 100644 --- a/drivers/doubao/meta.go +++ b/drivers/doubao/meta.go @@ -10,7 +10,8 @@ type Addition struct { // driver.RootPath driver.RootID // define other - Cookie string `json:"cookie" type:"text"` + Cookie string `json:"cookie" type:"text"` + UploadThread string `json:"upload_thread" default:"3"` } var config = driver.Config{ @@ -19,7 +20,7 @@ var config = driver.Config{ OnlyLocal: false, OnlyProxy: false, NoCache: false, - NoUpload: true, + NoUpload: false, NeedMs: false, DefaultRoot: "0", CheckStatus: false, diff --git a/drivers/doubao/types.go b/drivers/doubao/types.go index 2dc5a61d..4264eb7d 100644 --- a/drivers/doubao/types.go +++ b/drivers/doubao/types.go @@ -1,6 +1,11 @@ package doubao -import "github.com/alist-org/alist/v3/internal/model" +import ( + "encoding/json" + "fmt" + "github.com/alist-org/alist/v3/internal/model" + "time" +) type BaseResp struct { Code int `json:"code"` @@ -10,14 +15,14 @@ type BaseResp struct { type NodeInfoResp struct { BaseResp Data struct { - NodeInfo NodeInfo `json:"node_info"` - Children []NodeInfo `json:"children"` - NextCursor string `json:"next_cursor"` - HasMore bool `json:"has_more"` + NodeInfo File `json:"node_info"` + Children []File `json:"children"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` } `json:"data"` } -type NodeInfo struct { +type File struct { ID string `json:"id"` Name string `json:"name"` Key string `json:"key"` @@ -44,6 +49,39 @@ type GetFileUrlResp struct { } `json:"data"` } +type GetVideoFileUrlResp struct { + BaseResp + Data struct { + MediaType string `json:"media_type"` + MediaInfo []struct { + Meta struct { + Height string `json:"height"` + Width string `json:"width"` + Format string `json:"format"` + Duration float64 `json:"duration"` + CodecType string `json:"codec_type"` + Definition string `json:"definition"` + } `json:"meta"` + MainURL string `json:"main_url"` + BackupURL string `json:"backup_url"` + } `json:"media_info"` + OriginalMediaInfo struct { + Meta struct { + Height string `json:"height"` + Width string `json:"width"` + Format string `json:"format"` + Duration float64 `json:"duration"` + CodecType string `json:"codec_type"` + Definition string `json:"definition"` + } `json:"meta"` + MainURL string `json:"main_url"` + BackupURL string `json:"backup_url"` + } `json:"original_media_info"` + PosterURL string `json:"poster_url"` + PlayableStatus int `json:"playable_status"` + } `json:"data"` +} + type UploadNodeResp struct { BaseResp Data struct { @@ -60,5 +98,306 @@ type UploadNodeResp struct { type Object struct { model.Object - Key string + Key string + NodeType int +} + +type UserInfoResp struct { + Data UserInfo `json:"data"` + Message string `json:"message"` +} +type AppUserInfo struct { + BuiAuditInfo string `json:"bui_audit_info"` +} +type AuditInfo struct { +} +type Details struct { +} +type BuiAuditInfo struct { + AuditInfo AuditInfo `json:"audit_info"` + IsAuditing bool `json:"is_auditing"` + AuditStatus int `json:"audit_status"` + LastUpdateTime int `json:"last_update_time"` + UnpassReason string `json:"unpass_reason"` + Details Details `json:"details"` +} +type Connects struct { + Platform string `json:"platform"` + ProfileImageURL string `json:"profile_image_url"` + ExpiredTime int `json:"expired_time"` + ExpiresIn int `json:"expires_in"` + PlatformScreenName string `json:"platform_screen_name"` + UserID int64 `json:"user_id"` + PlatformUID string `json:"platform_uid"` + SecPlatformUID string `json:"sec_platform_uid"` + PlatformAppID int `json:"platform_app_id"` + ModifyTime int `json:"modify_time"` + AccessToken string `json:"access_token"` + OpenID string `json:"open_id"` +} +type OperStaffRelationInfo struct { + HasPassword int `json:"has_password"` + Mobile string `json:"mobile"` + SecOperStaffUserID string `json:"sec_oper_staff_user_id"` + RelationMobileCountryCode int `json:"relation_mobile_country_code"` +} +type UserInfo struct { + AppID int `json:"app_id"` + AppUserInfo AppUserInfo `json:"app_user_info"` + AvatarURL string `json:"avatar_url"` + BgImgURL string `json:"bg_img_url"` + BuiAuditInfo BuiAuditInfo `json:"bui_audit_info"` + CanBeFoundByPhone int `json:"can_be_found_by_phone"` + Connects []Connects `json:"connects"` + CountryCode int `json:"country_code"` + Description string `json:"description"` + DeviceID int `json:"device_id"` + Email string `json:"email"` + EmailCollected bool `json:"email_collected"` + Gender int `json:"gender"` + HasPassword int `json:"has_password"` + HmRegion int `json:"hm_region"` + IsBlocked int `json:"is_blocked"` + IsBlocking int `json:"is_blocking"` + IsRecommendAllowed int `json:"is_recommend_allowed"` + IsVisitorAccount bool `json:"is_visitor_account"` + Mobile string `json:"mobile"` + Name string `json:"name"` + NeedCheckBindStatus bool `json:"need_check_bind_status"` + OdinUserType int `json:"odin_user_type"` + OperStaffRelationInfo OperStaffRelationInfo `json:"oper_staff_relation_info"` + PhoneCollected bool `json:"phone_collected"` + RecommendHintMessage string `json:"recommend_hint_message"` + ScreenName string `json:"screen_name"` + SecUserID string `json:"sec_user_id"` + SessionKey string `json:"session_key"` + UseHmRegion bool `json:"use_hm_region"` + UserCreateTime int `json:"user_create_time"` + UserID int64 `json:"user_id"` + UserIDStr string `json:"user_id_str"` + UserVerified bool `json:"user_verified"` + VerifiedContent string `json:"verified_content"` +} + +// UploadToken 上传令牌配置 +type UploadToken struct { + Alice map[string]UploadAuthToken + Samantha MediaUploadAuthToken +} + +// UploadAuthToken 多种类型的上传配置:图片/文件 +type UploadAuthToken struct { + ServiceID string `json:"service_id"` + UploadPathPrefix string `json:"upload_path_prefix"` + Auth struct { + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + SessionToken string `json:"session_token"` + ExpiredTime time.Time `json:"expired_time"` + CurrentTime time.Time `json:"current_time"` + } `json:"auth"` + UploadHost string `json:"upload_host"` +} + +// MediaUploadAuthToken 媒体上传配置 +type MediaUploadAuthToken struct { + StsToken struct { + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + SessionToken string `json:"session_token"` + ExpiredTime time.Time `json:"expired_time"` + CurrentTime time.Time `json:"current_time"` + } `json:"sts_token"` + UploadInfo struct { + VideoHost string `json:"video_host"` + SpaceName string `json:"space_name"` + } `json:"upload_info"` +} + +type UploadAuthTokenResp struct { + BaseResp + Data UploadAuthToken `json:"data"` +} + +type MediaUploadAuthTokenResp struct { + BaseResp + Data MediaUploadAuthToken `json:"data"` +} + +type ResponseMetadata struct { + RequestID string `json:"RequestId"` + Action string `json:"Action"` + Version string `json:"Version"` + Service string `json:"Service"` + Region string `json:"Region"` + Error struct { + CodeN int `json:"CodeN,omitempty"` + Code string `json:"Code,omitempty"` + Message string `json:"Message,omitempty"` + } `json:"Error,omitempty"` +} + +type UploadConfig struct { + UploadAddress UploadAddress `json:"UploadAddress"` + FallbackUploadAddress FallbackUploadAddress `json:"FallbackUploadAddress"` + InnerUploadAddress InnerUploadAddress `json:"InnerUploadAddress"` + RequestID string `json:"RequestId"` + SDKParam interface{} `json:"SDKParam"` +} + +type UploadConfigResp struct { + ResponseMetadata `json:"ResponseMetadata"` + Result UploadConfig `json:"Result"` +} + +// StoreInfo 存储信息 +type StoreInfo struct { + StoreURI string `json:"StoreUri"` + Auth string `json:"Auth"` + UploadID string `json:"UploadID"` + UploadHeader map[string]interface{} `json:"UploadHeader,omitempty"` + StorageHeader map[string]interface{} `json:"StorageHeader,omitempty"` +} + +// UploadAddress 上传地址信息 +type UploadAddress struct { + StoreInfos []StoreInfo `json:"StoreInfos"` + UploadHosts []string `json:"UploadHosts"` + UploadHeader map[string]interface{} `json:"UploadHeader"` + SessionKey string `json:"SessionKey"` + Cloud string `json:"Cloud"` +} + +// FallbackUploadAddress 备用上传地址 +type FallbackUploadAddress struct { + StoreInfos []StoreInfo `json:"StoreInfos"` + UploadHosts []string `json:"UploadHosts"` + UploadHeader map[string]interface{} `json:"UploadHeader"` + SessionKey string `json:"SessionKey"` + Cloud string `json:"Cloud"` +} + +// UploadNode 上传节点信息 +type UploadNode struct { + Vid string `json:"Vid"` + Vids []string `json:"Vids"` + StoreInfos []StoreInfo `json:"StoreInfos"` + UploadHost string `json:"UploadHost"` + UploadHeader map[string]interface{} `json:"UploadHeader"` + Type string `json:"Type"` + Protocol string `json:"Protocol"` + SessionKey string `json:"SessionKey"` + NodeConfig struct { + UploadMode string `json:"UploadMode"` + } `json:"NodeConfig"` + Cluster string `json:"Cluster"` +} + +// AdvanceOption 高级选项 +type AdvanceOption struct { + Parallel int `json:"Parallel"` + Stream int `json:"Stream"` + SliceSize int `json:"SliceSize"` + EncryptionKey string `json:"EncryptionKey"` +} + +// InnerUploadAddress 内部上传地址 +type InnerUploadAddress struct { + UploadNodes []UploadNode `json:"UploadNodes"` + AdvanceOption AdvanceOption `json:"AdvanceOption"` +} + +// UploadPart 上传分片信息 +type UploadPart struct { + UploadId string `json:"uploadid,omitempty"` + PartNumber string `json:"part_number,omitempty"` + Crc32 string `json:"crc32,omitempty"` + Etag string `json:"etag,omitempty"` + Mode string `json:"mode,omitempty"` +} + +// UploadResp 上传响应体 +type UploadResp struct { + Code int `json:"code"` + ApiVersion string `json:"apiversion"` + Message string `json:"message"` + Data UploadPart `json:"data"` +} + +type VideoCommitUpload struct { + Vid string `json:"Vid"` + VideoMeta struct { + URI string `json:"Uri"` + Height int `json:"Height"` + Width int `json:"Width"` + OriginHeight int `json:"OriginHeight"` + OriginWidth int `json:"OriginWidth"` + Duration float64 `json:"Duration"` + Bitrate int `json:"Bitrate"` + Md5 string `json:"Md5"` + Format string `json:"Format"` + Size int `json:"Size"` + FileType string `json:"FileType"` + Codec string `json:"Codec"` + } `json:"VideoMeta"` + WorkflowInput struct { + TemplateID string `json:"TemplateId"` + } `json:"WorkflowInput"` + GetPosterMode string `json:"GetPosterMode"` +} + +type VideoCommitUploadResp struct { + ResponseMetadata ResponseMetadata `json:"ResponseMetadata"` + Result struct { + RequestID string `json:"RequestId"` + Results []VideoCommitUpload `json:"Results"` + } `json:"Result"` +} + +type CommonResp struct { + Code int `json:"code"` + Msg string `json:"msg,omitempty"` + Message string `json:"message,omitempty"` // 错误情况下的消息 + Data json.RawMessage `json:"data,omitempty"` // 原始数据,稍后解析 + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + Locale string `json:"locale"` + } `json:"error,omitempty"` +} + +// IsSuccess 判断响应是否成功 +func (r *CommonResp) IsSuccess() bool { + return r.Code == 0 +} + +// GetError 获取错误信息 +func (r *CommonResp) GetError() error { + if r.IsSuccess() { + return nil + } + // 优先使用message字段 + errMsg := r.Message + if errMsg == "" { + errMsg = r.Msg + } + // 如果error对象存在且有详细消息,则使用error中的信息 + if r.Error != nil && r.Error.Message != "" { + errMsg = r.Error.Message + } + + return fmt.Errorf("[doubao] API error (code: %d): %s", r.Code, errMsg) +} + +// UnmarshalData 将data字段解析为指定类型 +func (r *CommonResp) UnmarshalData(v interface{}) error { + if !r.IsSuccess() { + return r.GetError() + } + + if len(r.Data) == 0 { + return nil + } + + return json.Unmarshal(r.Data, v) } diff --git a/drivers/doubao/util.go b/drivers/doubao/util.go index 977691c0..348c0aa0 100644 --- a/drivers/doubao/util.go +++ b/drivers/doubao/util.go @@ -1,38 +1,970 @@ package doubao import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" - + "fmt" "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/errgroup" "github.com/alist-org/alist/v3/pkg/utils" + "github.com/avast/retry-go" + "github.com/go-resty/resty/v2" + "github.com/google/uuid" log "github.com/sirupsen/logrus" + "hash/crc32" + "io" + "math" + "math/rand" + "net/http" + "net/url" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +const ( + DirectoryType = 1 + FileType = 2 + LinkType = 3 + ImageType = 4 + PagesType = 5 + VideoType = 6 + AudioType = 7 + MeetingMinutesType = 8 +) + +var FileNodeType = map[int]string{ + 1: "directory", + 2: "file", + 3: "link", + 4: "image", + 5: "pages", + 6: "video", + 7: "audio", + 8: "meeting_minutes", +} + +const ( + BaseURL = "https://www.doubao.com" + FileDataType = "file" + ImgDataType = "image" + VideoDataType = "video" + DefaultChunkSize = int64(5 * 1024 * 1024) // 5MB + MaxRetryAttempts = 3 // 最大重试次数 + UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36" + Region = "cn-north-1" + UploadTimeout = 3 * time.Minute ) // do others that not defined in Driver interface func (d *Doubao) request(path string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { - url := "https://www.doubao.com" + path + reqUrl := BaseURL + path req := base.RestyClient.R() req.SetHeader("Cookie", d.Cookie) if callback != nil { callback(req) } - var r BaseResp - req.SetResult(&r) - res, err := req.Execute(method, url) + + var commonResp CommonResp + + res, err := req.Execute(method, reqUrl) log.Debugln(res.String()) if err != nil { return nil, err } - // 业务状态码检查(优先于HTTP状态码) - if r.Code != 0 { - return res.Body(), errors.New(r.Msg) + body := res.Body() + // 先解析为通用响应 + if err = json.Unmarshal(body, &commonResp); err != nil { + return nil, err } + // 检查响应是否成功 + if !commonResp.IsSuccess() { + return body, commonResp.GetError() + } + if resp != nil { - err = utils.Json.Unmarshal(res.Body(), resp) + if err = json.Unmarshal(body, resp); err != nil { + return body, err + } + } + + return body, nil +} + +func (d *Doubao) getFiles(dirId, cursor string) (resp []File, err error) { + var r NodeInfoResp + + var body = base.Json{ + "node_id": dirId, + } + // 如果有游标,则设置游标和大小 + if cursor != "" { + body["cursor"] = cursor + body["size"] = 50 + } else { + body["need_full_path"] = false + } + + _, err = d.request("/samantha/aispace/node_info", http.MethodPost, func(req *resty.Request) { + req.SetBody(body) + }, &r) + if err != nil { + return nil, err + } + + if r.Data.Children != nil { + resp = r.Data.Children + } + + if r.Data.NextCursor != "-1" { + // 递归获取下一页 + nextFiles, err := d.getFiles(dirId, r.Data.NextCursor) if err != nil { return nil, err } + + resp = append(r.Data.Children, nextFiles...) } + + return resp, err +} + +func (d *Doubao) getUserInfo() (UserInfo, error) { + var r UserInfoResp + + _, err := d.request("/passport/account/info/v2/", http.MethodGet, nil, &r) + if err != nil { + return UserInfo{}, err + } + + return r.Data, err +} + +// 签名请求 +func (d *Doubao) signRequest(req *resty.Request, method, tokenType, uploadUrl string) error { + parsedUrl, err := url.Parse(uploadUrl) + if err != nil { + return fmt.Errorf("invalid URL format: %w", err) + } + + var accessKeyId, secretAccessKey, sessionToken string + var serviceName string + + if tokenType == VideoDataType { + accessKeyId = d.UploadToken.Samantha.StsToken.AccessKeyID + secretAccessKey = d.UploadToken.Samantha.StsToken.SecretAccessKey + sessionToken = d.UploadToken.Samantha.StsToken.SessionToken + serviceName = "vod" + } else { + accessKeyId = d.UploadToken.Alice[tokenType].Auth.AccessKeyID + secretAccessKey = d.UploadToken.Alice[tokenType].Auth.SecretAccessKey + sessionToken = d.UploadToken.Alice[tokenType].Auth.SessionToken + serviceName = "imagex" + } + + // 当前时间,格式为 ISO8601 + now := time.Now().UTC() + amzDate := now.Format("20060102T150405Z") + dateStamp := now.Format("20060102") + + req.SetHeader("X-Amz-Date", amzDate) + + if sessionToken != "" { + req.SetHeader("X-Amz-Security-Token", sessionToken) + } + + // 计算请求体的SHA256哈希 + var bodyHash string + if req.Body != nil { + bodyBytes, ok := req.Body.([]byte) + if !ok { + return fmt.Errorf("request body must be []byte") + } + + bodyHash = hashSHA256(string(bodyBytes)) + req.SetHeader("X-Amz-Content-Sha256", bodyHash) + } else { + bodyHash = hashSHA256("") + } + + // 创建规范请求 + canonicalURI := parsedUrl.Path + if canonicalURI == "" { + canonicalURI = "/" + } + + // 查询参数按照字母顺序排序 + canonicalQueryString := getCanonicalQueryString(req.QueryParam) + // 规范请求头 + canonicalHeaders, signedHeaders := getCanonicalHeadersFromMap(req.Header) + canonicalRequest := method + "\n" + + canonicalURI + "\n" + + canonicalQueryString + "\n" + + canonicalHeaders + "\n" + + signedHeaders + "\n" + + bodyHash + + algorithm := "AWS4-HMAC-SHA256" + credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, Region, serviceName) + + stringToSign := algorithm + "\n" + + amzDate + "\n" + + credentialScope + "\n" + + hashSHA256(canonicalRequest) + // 计算签名密钥 + signingKey := getSigningKey(secretAccessKey, dateStamp, Region, serviceName) + // 计算签名 + signature := hmacSHA256Hex(signingKey, stringToSign) + // 构建授权头 + authorizationHeader := fmt.Sprintf( + "%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", + algorithm, + accessKeyId, + credentialScope, + signedHeaders, + signature, + ) + + req.SetHeader("Authorization", authorizationHeader) + + return nil +} + +func (d *Doubao) requestApi(url, method, tokenType string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "user-agent": UserAgent, + }) + + if method == http.MethodPost { + req.SetHeader("Content-Type", "text/plain;charset=UTF-8") + } + + if callback != nil { + callback(req) + } + + if resp != nil { + req.SetResult(resp) + } + + // 使用自定义AWS SigV4签名 + err := d.signRequest(req, method, tokenType, url) + if err != nil { + return nil, err + } + + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + return res.Body(), nil } + +func (d *Doubao) initUploadToken() (*UploadToken, error) { + uploadToken := &UploadToken{ + Alice: make(map[string]UploadAuthToken), + Samantha: MediaUploadAuthToken{}, + } + + fileAuthToken, err := d.getUploadAuthToken(FileDataType) + if err != nil { + return nil, err + } + + imgAuthToken, err := d.getUploadAuthToken(ImgDataType) + if err != nil { + return nil, err + } + + mediaAuthToken, err := d.getSamantaUploadAuthToken() + if err != nil { + return nil, err + } + + uploadToken.Alice[FileDataType] = fileAuthToken + uploadToken.Alice[ImgDataType] = imgAuthToken + uploadToken.Samantha = mediaAuthToken + + return uploadToken, nil +} + +func (d *Doubao) getUploadAuthToken(dataType string) (ut UploadAuthToken, err error) { + var r UploadAuthTokenResp + _, err = d.request("/alice/upload/auth_token", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "scene": "bot_chat", + "data_type": dataType, + }) + }, &r) + + return r.Data, err +} + +func (d *Doubao) getSamantaUploadAuthToken() (mt MediaUploadAuthToken, err error) { + var r MediaUploadAuthTokenResp + _, err = d.request("/samantha/media/get_upload_token", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{}) + }, &r) + + return r.Data, err +} + +// getUploadConfig 获取上传配置信息 +func (d *Doubao) getUploadConfig(upConfig *UploadConfig, dataType string, file model.FileStreamer) error { + tokenType := dataType + // 配置参数函数 + configureParams := func() (string, map[string]string) { + var uploadUrl string + var params map[string]string + // 根据数据类型设置不同的上传参数 + switch dataType { + case VideoDataType: + // 音频/视频类型 - 使用uploadToken.Samantha的配置 + uploadUrl = d.UploadToken.Samantha.UploadInfo.VideoHost + params = map[string]string{ + "Action": "ApplyUploadInner", + "Version": "2020-11-19", + "SpaceName": d.UploadToken.Samantha.UploadInfo.SpaceName, + "FileType": "video", + "IsInner": "1", + "NeedFallback": "true", + "FileSize": strconv.FormatInt(file.GetSize(), 10), + "s": randomString(), + } + case ImgDataType, FileDataType: + // 图片或其他文件类型 - 使用uploadToken.Alice对应配置 + uploadUrl = "https://" + d.UploadToken.Alice[dataType].UploadHost + params = map[string]string{ + "Action": "ApplyImageUpload", + "Version": "2018-08-01", + "ServiceId": d.UploadToken.Alice[dataType].ServiceID, + "NeedFallback": "true", + "FileSize": strconv.FormatInt(file.GetSize(), 10), + "FileExtension": filepath.Ext(file.GetName()), + "s": randomString(), + } + } + return uploadUrl, params + } + + // 获取初始参数 + uploadUrl, params := configureParams() + + tokenRefreshed := false + var configResp UploadConfigResp + + err := d._retryOperation("get upload_config", func() error { + configResp = UploadConfigResp{} + + _, err := d.requestApi(uploadUrl, http.MethodGet, tokenType, func(req *resty.Request) { + req.SetQueryParams(params) + }, &configResp) + if err != nil { + return err + } + + if configResp.ResponseMetadata.Error.Code == "" { + *upConfig = configResp.Result + return nil + } + + // 100028 凭证过期 + if configResp.ResponseMetadata.Error.CodeN == 100028 && !tokenRefreshed { + log.Debugln("[doubao] Upload token expired, re-fetching...") + newToken, err := d.initUploadToken() + if err != nil { + return fmt.Errorf("failed to refresh token: %w", err) + } + + d.UploadToken = newToken + tokenRefreshed = true + uploadUrl, params = configureParams() + + return retry.Error{errors.New("token refreshed, retry needed")} + } + + return fmt.Errorf("get upload_config failed: %s", configResp.ResponseMetadata.Error.Message) + }) + + return err +} + +// uploadNode 上传 文件信息 +func (d *Doubao) uploadNode(uploadConfig *UploadConfig, dir model.Obj, file model.FileStreamer, dataType string) (UploadNodeResp, error) { + reqUuid := uuid.New().String() + var key string + var nodeType int + + mimetype := file.GetMimetype() + switch dataType { + case VideoDataType: + key = uploadConfig.InnerUploadAddress.UploadNodes[0].Vid + if strings.HasPrefix(mimetype, "audio/") { + nodeType = AudioType // 音频类型 + } else { + nodeType = VideoType // 视频类型 + } + case ImgDataType: + key = uploadConfig.InnerUploadAddress.UploadNodes[0].StoreInfos[0].StoreURI + nodeType = ImageType // 图片类型 + default: // FileDataType + key = uploadConfig.InnerUploadAddress.UploadNodes[0].StoreInfos[0].StoreURI + nodeType = FileType // 文件类型 + } + + var r UploadNodeResp + _, err := d.request("/samantha/aispace/upload_node", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "node_list": []base.Json{ + { + "local_id": reqUuid, + "parent_id": dir.GetID(), + "name": file.GetName(), + "key": key, + "node_content": base.Json{}, + "node_type": nodeType, + "size": file.GetSize(), + }, + }, + "request_id": reqUuid, + }) + }, &r) + + return r, err +} + +// Upload 普通上传实现 +func (d *Doubao) Upload(config *UploadConfig, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, dataType string) (model.Obj, error) { + data, err := io.ReadAll(file) + if err != nil { + return nil, err + } + + // 计算CRC32 + crc32Hash := crc32.NewIEEE() + crc32Hash.Write(data) + crc32Value := hex.EncodeToString(crc32Hash.Sum(nil)) + + // 构建请求路径 + uploadNode := config.InnerUploadAddress.UploadNodes[0] + storeInfo := uploadNode.StoreInfos[0] + uploadUrl := fmt.Sprintf("https://%s/upload/v1/%s", uploadNode.UploadHost, storeInfo.StoreURI) + + uploadResp := UploadResp{} + + if _, err = d.uploadRequest(uploadUrl, http.MethodPost, storeInfo, func(req *resty.Request) { + req.SetHeaders(map[string]string{ + "Content-Type": "application/octet-stream", + "Content-Crc32": crc32Value, + "Content-Length": fmt.Sprintf("%d", len(data)), + "Content-Disposition": fmt.Sprintf("attachment; filename=%s", url.QueryEscape(storeInfo.StoreURI)), + }) + + req.SetBody(data) + }, &uploadResp); err != nil { + return nil, err + } + + if uploadResp.Code != 2000 { + return nil, fmt.Errorf("upload failed: %s", uploadResp.Message) + } + + uploadNodeResp, err := d.uploadNode(config, dstDir, file, dataType) + if err != nil { + return nil, err + } + + return &model.Object{ + ID: uploadNodeResp.Data.NodeList[0].ID, + Name: uploadNodeResp.Data.NodeList[0].Name, + Size: file.GetSize(), + IsFolder: false, + }, nil +} + +// UploadByMultipart 分片上传 +func (d *Doubao) UploadByMultipart(ctx context.Context, config *UploadConfig, fileSize int64, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, dataType string) (model.Obj, error) { + // 构建请求路径 + uploadNode := config.InnerUploadAddress.UploadNodes[0] + storeInfo := uploadNode.StoreInfos[0] + uploadUrl := fmt.Sprintf("https://%s/upload/v1/%s", uploadNode.UploadHost, storeInfo.StoreURI) + // 初始化分片上传 + var uploadID string + err := d._retryOperation("Initialize multipart upload", func() error { + var err error + uploadID, err = d.initMultipartUpload(config, uploadUrl, storeInfo) + return err + }) + if err != nil { + return nil, fmt.Errorf("failed to initialize multipart upload: %w", err) + } + // 准备分片参数 + chunkSize := DefaultChunkSize + if config.InnerUploadAddress.AdvanceOption.SliceSize > 0 { + chunkSize = int64(config.InnerUploadAddress.AdvanceOption.SliceSize) + } + totalParts := (fileSize + chunkSize - 1) / chunkSize + // 创建分片信息组 + parts := make([]UploadPart, totalParts) + // 缓存文件 + tempFile, err := file.CacheFullInTempFile() + if err != nil { + return nil, fmt.Errorf("failed to cache file: %w", err) + } + defer tempFile.Close() + up(10.0) // 更新进度 + // 设置并行上传 + threadG, uploadCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, + retry.Attempts(1), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) + + var partsMutex sync.Mutex + // 并行上传所有分片 + for partIndex := int64(0); partIndex < totalParts; partIndex++ { + if utils.IsCanceled(uploadCtx) { + break + } + partIndex := partIndex + partNumber := partIndex + 1 // 分片编号从1开始 + + threadG.Go(func(ctx context.Context) error { + // 计算此分片的大小和偏移 + offset := partIndex * chunkSize + size := chunkSize + if partIndex == totalParts-1 { + size = fileSize - offset + } + + limitedReader := driver.NewLimitedUploadStream(ctx, io.NewSectionReader(tempFile, offset, size)) + // 读取数据到内存 + data, err := io.ReadAll(limitedReader) + if err != nil { + return fmt.Errorf("failed to read part %d: %w", partNumber, err) + } + // 计算CRC32 + crc32Value := calculateCRC32(data) + // 使用_retryOperation上传分片 + var uploadPart UploadPart + if err = d._retryOperation(fmt.Sprintf("Upload part %d", partNumber), func() error { + var err error + uploadPart, err = d.uploadPart(config, uploadUrl, uploadID, partNumber, data, crc32Value) + return err + }); err != nil { + return fmt.Errorf("part %d upload failed: %w", partNumber, err) + } + // 记录成功上传的分片 + partsMutex.Lock() + parts[partIndex] = UploadPart{ + PartNumber: strconv.FormatInt(partNumber, 10), + Etag: uploadPart.Etag, + Crc32: crc32Value, + } + partsMutex.Unlock() + // 更新进度 + progress := 10.0 + 90.0*float64(threadG.Success()+1)/float64(totalParts) + up(math.Min(progress, 95.0)) + + return nil + }) + } + + if err = threadG.Wait(); err != nil { + return nil, err + } + // 完成上传-分片合并 + if err = d._retryOperation("Complete multipart upload", func() error { + return d.completeMultipartUpload(config, uploadUrl, uploadID, parts) + }); err != nil { + return nil, fmt.Errorf("failed to complete multipart upload: %w", err) + } + // 提交上传 + if err = d._retryOperation("Commit upload", func() error { + return d.commitMultipartUpload(config) + }); err != nil { + return nil, fmt.Errorf("failed to commit upload: %w", err) + } + + up(98.0) // 更新到98% + // 上传节点信息 + var uploadNodeResp UploadNodeResp + + if err = d._retryOperation("Upload node", func() error { + var err error + uploadNodeResp, err = d.uploadNode(config, dstDir, file, dataType) + return err + }); err != nil { + return nil, fmt.Errorf("failed to upload node: %w", err) + } + + up(100.0) // 完成上传 + + return &model.Object{ + ID: uploadNodeResp.Data.NodeList[0].ID, + Name: uploadNodeResp.Data.NodeList[0].Name, + Size: file.GetSize(), + IsFolder: false, + }, nil +} + +// 统一上传请求方法 +func (d *Doubao) uploadRequest(uploadUrl string, method string, storeInfo StoreInfo, callback base.ReqCallback, resp interface{}) ([]byte, error) { + client := resty.New() + client.SetTransport(&http.Transport{ + DisableKeepAlives: true, // 禁用连接复用 + ForceAttemptHTTP2: false, // 强制使用HTTP/1.1 + }) + client.SetTimeout(UploadTimeout) + + req := client.R() + req.SetHeaders(map[string]string{ + "Host": strings.Split(uploadUrl, "/")[2], + "Referer": BaseURL + "/", + "Origin": BaseURL, + "User-Agent": UserAgent, + "X-Storage-U": d.UserId, + "Authorization": storeInfo.Auth, + }) + + if method == http.MethodPost { + req.SetHeader("Content-Type", "text/plain;charset=UTF-8") + } + + if callback != nil { + callback(req) + } + + if resp != nil { + req.SetResult(resp) + } + + res, err := req.Execute(method, uploadUrl) + if err != nil && err != io.EOF { + return nil, fmt.Errorf("upload request failed: %w", err) + } + + return res.Body(), nil +} + +// 初始化分片上传 +func (d *Doubao) initMultipartUpload(config *UploadConfig, uploadUrl string, storeInfo StoreInfo) (uploadId string, err error) { + uploadResp := UploadResp{} + + _, err = d.uploadRequest(uploadUrl, http.MethodPost, storeInfo, func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "uploadmode": "part", + "phase": "init", + }) + }, &uploadResp) + + if err != nil { + return uploadId, err + } + + if uploadResp.Code != 2000 { + return uploadId, fmt.Errorf("init upload failed: %s", uploadResp.Message) + } + + return uploadResp.Data.UploadId, nil +} + +// 分片上传实现 +func (d *Doubao) uploadPart(config *UploadConfig, uploadUrl, uploadID string, partNumber int64, data []byte, crc32Value string) (resp UploadPart, err error) { + uploadResp := UploadResp{} + storeInfo := config.InnerUploadAddress.UploadNodes[0].StoreInfos[0] + + _, err = d.uploadRequest(uploadUrl, http.MethodPost, storeInfo, func(req *resty.Request) { + req.SetHeaders(map[string]string{ + "Content-Type": "application/octet-stream", + "Content-Crc32": crc32Value, + "Content-Length": fmt.Sprintf("%d", len(data)), + "Content-Disposition": fmt.Sprintf("attachment; filename=%s", url.QueryEscape(storeInfo.StoreURI)), + }) + + req.SetQueryParams(map[string]string{ + "uploadid": uploadID, + "part_number": strconv.FormatInt(partNumber, 10), + "phase": "transfer", + }) + + req.SetBody(data) + req.SetContentLength(true) + }, &uploadResp) + + if err != nil { + return resp, err + } + + if uploadResp.Code != 2000 { + return resp, fmt.Errorf("upload part failed: %s", uploadResp.Message) + } else if uploadResp.Data.Crc32 != crc32Value { + return resp, fmt.Errorf("upload part failed: crc32 mismatch, expected %s, got %s", crc32Value, uploadResp.Data.Crc32) + } + + return uploadResp.Data, nil +} + +// 完成分片上传 +func (d *Doubao) completeMultipartUpload(config *UploadConfig, uploadUrl, uploadID string, parts []UploadPart) error { + uploadResp := UploadResp{} + + storeInfo := config.InnerUploadAddress.UploadNodes[0].StoreInfos[0] + + body := _convertUploadParts(parts) + + err := utils.Retry(MaxRetryAttempts, time.Second, func() (err error) { + _, err = d.uploadRequest(uploadUrl, http.MethodPost, storeInfo, func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "uploadid": uploadID, + "phase": "finish", + "uploadmode": "part", + }) + req.SetBody(body) + }, &uploadResp) + + if err != nil { + return err + } + // 检查响应状态码 2000 成功 4024 分片合并中 + if uploadResp.Code != 2000 && uploadResp.Code != 4024 { + return fmt.Errorf("finish upload failed: %s", uploadResp.Message) + } + + return err + }) + + if err != nil { + return fmt.Errorf("failed to complete multipart upload: %w", err) + } + + return nil +} + +func (d *Doubao) commitMultipartUpload(uploadConfig *UploadConfig) error { + uploadUrl := d.UploadToken.Samantha.UploadInfo.VideoHost + params := map[string]string{ + "Action": "CommitUploadInner", + "Version": "2020-11-19", + "SpaceName": d.UploadToken.Samantha.UploadInfo.SpaceName, + } + tokenType := VideoDataType + + videoCommitUploadResp := VideoCommitUploadResp{} + + jsonBytes, err := json.Marshal(base.Json{ + "SessionKey": uploadConfig.InnerUploadAddress.UploadNodes[0].SessionKey, + "Functions": []base.Json{}, + }) + if err != nil { + return fmt.Errorf("failed to marshal request data: %w", err) + } + + _, err = d.requestApi(uploadUrl, http.MethodPost, tokenType, func(req *resty.Request) { + req.SetHeader("Content-Type", "application/json") + req.SetQueryParams(params) + req.SetBody(jsonBytes) + + }, &videoCommitUploadResp) + if err != nil { + return err + } + + return nil +} + +// 计算CRC32 +func calculateCRC32(data []byte) string { + hash := crc32.NewIEEE() + hash.Write(data) + return hex.EncodeToString(hash.Sum(nil)) +} + +// _retryOperation 操作重试 +func (d *Doubao) _retryOperation(operation string, fn func() error) error { + return retry.Do( + fn, + retry.Attempts(MaxRetryAttempts), + retry.Delay(500*time.Millisecond), + retry.DelayType(retry.BackOffDelay), + retry.MaxJitter(200*time.Millisecond), + retry.OnRetry(func(n uint, err error) { + log.Debugf("[doubao] %s retry #%d: %v", operation, n+1, err) + }), + ) +} + +// _convertUploadParts 将分片信息转换为字符串 +func _convertUploadParts(parts []UploadPart) string { + if len(parts) == 0 { + return "" + } + + var result strings.Builder + + for i, part := range parts { + if i > 0 { + result.WriteString(",") + } + result.WriteString(fmt.Sprintf("%s:%s", part.PartNumber, part.Crc32)) + } + + return result.String() +} + +// 获取规范查询字符串 +func getCanonicalQueryString(query url.Values) string { + if len(query) == 0 { + return "" + } + + keys := make([]string, 0, len(query)) + for k := range query { + keys = append(keys, k) + } + sort.Strings(keys) + + parts := make([]string, 0, len(keys)) + for _, k := range keys { + values := query[k] + for _, v := range values { + parts = append(parts, urlEncode(k)+"="+urlEncode(v)) + } + } + + return strings.Join(parts, "&") +} + +func urlEncode(s string) string { + s = url.QueryEscape(s) + s = strings.ReplaceAll(s, "+", "%20") + return s +} + +// 获取规范头信息和已签名头列表 +func getCanonicalHeadersFromMap(headers map[string][]string) (string, string) { + // 不可签名的头部列表 + unsignableHeaders := map[string]bool{ + "authorization": true, + "content-type": true, + "content-length": true, + "user-agent": true, + "presigned-expires": true, + "expect": true, + "x-amzn-trace-id": true, + } + headerValues := make(map[string]string) + var signedHeadersList []string + + for k, v := range headers { + if len(v) == 0 { + continue + } + + lowerKey := strings.ToLower(k) + // 检查是否可签名 + if strings.HasPrefix(lowerKey, "x-amz-") || !unsignableHeaders[lowerKey] { + value := strings.TrimSpace(v[0]) + value = strings.Join(strings.Fields(value), " ") + headerValues[lowerKey] = value + signedHeadersList = append(signedHeadersList, lowerKey) + } + } + + sort.Strings(signedHeadersList) + + var canonicalHeadersStr strings.Builder + for _, key := range signedHeadersList { + canonicalHeadersStr.WriteString(key) + canonicalHeadersStr.WriteString(":") + canonicalHeadersStr.WriteString(headerValues[key]) + canonicalHeadersStr.WriteString("\n") + } + + signedHeaders := strings.Join(signedHeadersList, ";") + + return canonicalHeadersStr.String(), signedHeaders +} + +// 计算HMAC-SHA256 +func hmacSHA256(key []byte, data string) []byte { + h := hmac.New(sha256.New, key) + h.Write([]byte(data)) + return h.Sum(nil) +} + +// 计算HMAC-SHA256并返回十六进制字符串 +func hmacSHA256Hex(key []byte, data string) string { + return hex.EncodeToString(hmacSHA256(key, data)) +} + +// 计算SHA256哈希并返回十六进制字符串 +func hashSHA256(data string) string { + h := sha256.New() + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) +} + +// 获取签名密钥 +func getSigningKey(secretKey, dateStamp, region, service string) []byte { + kDate := hmacSHA256([]byte("AWS4"+secretKey), dateStamp) + kRegion := hmacSHA256(kDate, region) + kService := hmacSHA256(kRegion, service) + kSigning := hmacSHA256(kService, "aws4_request") + return kSigning +} + +// generateContentDisposition 生成符合RFC 5987标准的Content-Disposition头部 +func generateContentDisposition(filename string) string { + // 按照RFC 2047进行编码,用于filename部分 + encodedName := urlEncode(filename) + + // 按照RFC 5987进行编码,用于filename*部分 + encodedNameRFC5987 := encodeRFC5987(filename) + + return fmt.Sprintf("attachment; filename=\"%s\"; filename*=utf-8''%s", + encodedName, encodedNameRFC5987) +} + +// encodeRFC5987 按照RFC 5987规范编码字符串,适用于HTTP头部参数中的非ASCII字符 +func encodeRFC5987(s string) string { + var buf strings.Builder + for _, r := range []byte(s) { + // 根据RFC 5987,只有字母、数字和部分特殊符号可以不编码 + if (r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') || + r == '-' || r == '.' || r == '_' || r == '~' { + buf.WriteByte(r) + } else { + // 其他字符都需要百分号编码 + fmt.Fprintf(&buf, "%%%02X", r) + } + } + return buf.String() +} + +func randomString() string { + const charset = "0123456789abcdefghijklmnopqrstuvwxyz" + const length = 11 // 11位随机字符串 + + var sb strings.Builder + sb.Grow(length) + + for i := 0; i < length; i++ { + sb.WriteByte(charset[rand.Intn(len(charset))]) + } + + return sb.String() +} diff --git a/drivers/github/util.go b/drivers/github/util.go index 03318784..7ddf8746 100644 --- a/drivers/github/util.go +++ b/drivers/github/util.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "io" "strings" "text/template" "time" @@ -159,7 +158,7 @@ func signCommit(m *map[string]interface{}, entity *openpgp.Entity) (string, erro if err != nil { return "", err } - if _, err = io.Copy(armorWriter, &sigBuffer); err != nil { + if _, err = utils.CopyWithBuffer(armorWriter, &sigBuffer); err != nil { return "", err } _ = armorWriter.Close() diff --git a/drivers/ilanzou/driver.go b/drivers/ilanzou/driver.go index 39a311dd..044193d3 100644 --- a/drivers/ilanzou/driver.go +++ b/drivers/ilanzou/driver.go @@ -2,7 +2,6 @@ package template import ( "context" - "crypto/md5" "encoding/base64" "encoding/hex" "fmt" @@ -17,6 +16,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/foxxorcat/mopan-sdk-go" "github.com/go-resty/resty/v2" @@ -273,23 +273,14 @@ func (d *ILanZou) Remove(ctx context.Context, obj model.Obj) error { const DefaultPartSize = 1024 * 1024 * 8 func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { - h := md5.New() - // need to calculate md5 of the full content - tempFile, err := s.CacheFullInTempFile() - if err != nil { - return nil, err + etag := s.GetHash().GetHash(utils.MD5) + var err error + if len(etag) != utils.MD5.Width { + _, etag, err = stream.CacheFullInTempFileAndHash(s, utils.MD5) + if err != nil { + return nil, err + } } - defer func() { - _ = tempFile.Close() - }() - if _, err = utils.CopyWithBuffer(h, tempFile); err != nil { - return nil, err - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return nil, err - } - etag := hex.EncodeToString(h.Sum(nil)) // get upToken res, err := d.proved("/7n/getUpToken", http.MethodPost, func(req *resty.Request) { req.SetBody(base.Json{ @@ -309,7 +300,7 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreame key := fmt.Sprintf("disk/%d/%d/%d/%s/%016d", now.Year(), now.Month(), now.Day(), d.account, now.UnixMilli()) reader := driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: &driver.SimpleReaderWithSize{ - Reader: tempFile, + Reader: s, Size: s.GetSize(), }, UpdateProgress: up, diff --git a/drivers/ipfs_api/driver.go b/drivers/ipfs_api/driver.go index e59da7ca..264cef28 100644 --- a/drivers/ipfs_api/driver.go +++ b/drivers/ipfs_api/driver.go @@ -4,8 +4,7 @@ import ( "context" "fmt" "net/url" - "path/filepath" - "strings" + "path" shell "github.com/ipfs/go-ipfs-api" @@ -43,78 +42,115 @@ func (d *IPFS) Drop(ctx context.Context) error { } func (d *IPFS) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { - path := dir.GetPath() - switch d.Mode { - case "ipfs": - path, _ = url.JoinPath("/ipfs", path) - case "ipns": - path, _ = url.JoinPath("/ipns", path) - case "mfs": - fileStat, err := d.sh.FilesStat(ctx, path) - if err != nil { - return nil, err + var ipfsPath string + cid := dir.GetID() + if cid != "" { + ipfsPath = path.Join("/ipfs", cid) + } else { + // 可能出现ipns dns解析失败的情况,需要重复获取cid,其他情况应该不会出错 + ipfsPath = dir.GetPath() + switch d.Mode { + case "ipfs": + ipfsPath = path.Join("/ipfs", ipfsPath) + case "ipns": + ipfsPath = path.Join("/ipns", ipfsPath) + case "mfs": + fileStat, err := d.sh.FilesStat(ctx, ipfsPath) + if err != nil { + return nil, err + } + ipfsPath = path.Join("/ipfs", fileStat.Hash) + default: + return nil, fmt.Errorf("mode error") } - path, _ = url.JoinPath("/ipfs", fileStat.Hash) - default: - return nil, fmt.Errorf("mode error") } - - dirs, err := d.sh.List(path) + dirs, err := d.sh.List(ipfsPath) if err != nil { return nil, err } objlist := []model.Obj{} for _, file := range dirs { - gateurl := *d.gateURL.JoinPath("/ipfs/" + file.Hash) - gateurl.RawQuery = "filename=" + url.PathEscape(file.Name) - objlist = append(objlist, &model.ObjectURL{ - Object: model.Object{ID: "/ipfs/" + file.Hash, Name: file.Name, Size: int64(file.Size), IsFolder: file.Type == 1}, - Url: model.Url{Url: gateurl.String()}, - }) + objlist = append(objlist, &model.Object{ID: file.Hash, Name: file.Name, Size: int64(file.Size), IsFolder: file.Type == 1}) } return objlist, nil } func (d *IPFS) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { - gateurl := d.gateURL.JoinPath(file.GetID()) - gateurl.RawQuery = "filename=" + url.PathEscape(file.GetName()) + gateurl := d.gateURL.JoinPath("/ipfs/", file.GetID()) + gateurl.RawQuery = "filename=" + url.QueryEscape(file.GetName()) return &model.Link{URL: gateurl.String()}, nil } -func (d *IPFS) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { - if d.Mode != "mfs" { - return fmt.Errorf("only write in mfs mode") +func (d *IPFS) Get(ctx context.Context, rawPath string) (model.Obj, error) { + rawPath = path.Join(d.GetRootPath(), rawPath) + var ipfsPath string + switch d.Mode { + case "ipfs": + ipfsPath = path.Join("/ipfs", rawPath) + case "ipns": + ipfsPath = path.Join("/ipns", rawPath) + case "mfs": + fileStat, err := d.sh.FilesStat(ctx, rawPath) + if err != nil { + return nil, err + } + ipfsPath = path.Join("/ipfs", fileStat.Hash) + default: + return nil, fmt.Errorf("mode error") } - path := parentDir.GetPath() - if path[len(path):] != "/" { - path += "/" + file, err := d.sh.FilesStat(ctx, ipfsPath) + if err != nil { + return nil, err } - return d.sh.FilesMkdir(ctx, path+dirName) + return &model.Object{ID: file.Hash, Name: path.Base(rawPath), Path: rawPath, Size: int64(file.Size), IsFolder: file.Type == "directory"}, nil } -func (d *IPFS) Move(ctx context.Context, srcObj, dstDir model.Obj) error { +func (d *IPFS) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { if d.Mode != "mfs" { - return fmt.Errorf("only write in mfs mode") + return nil, fmt.Errorf("only write in mfs mode") } - return d.sh.FilesMv(ctx, srcObj.GetPath(), dstDir.GetPath()) + dirPath := parentDir.GetPath() + err := d.sh.FilesMkdir(ctx, path.Join(dirPath, dirName), shell.FilesMkdir.Parents(true)) + if err != nil { + return nil, err + } + file, err := d.sh.FilesStat(ctx, path.Join(dirPath, dirName)) + if err != nil { + return nil, err + } + return &model.Object{ID: file.Hash, Name: dirName, Path: path.Join(dirPath, dirName), Size: int64(file.Size), IsFolder: true}, nil } -func (d *IPFS) Rename(ctx context.Context, srcObj model.Obj, newName string) error { +func (d *IPFS) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { if d.Mode != "mfs" { - return fmt.Errorf("only write in mfs mode") + return nil, fmt.Errorf("only write in mfs mode") } - newFileName := filepath.Dir(srcObj.GetPath()) + "/" + newName - return d.sh.FilesMv(ctx, srcObj.GetPath(), strings.ReplaceAll(newFileName, "\\", "/")) + dstPath := path.Join(dstDir.GetPath(), path.Base(srcObj.GetPath())) + d.sh.FilesRm(ctx, dstPath, true) + return &model.Object{ID: srcObj.GetID(), Name: srcObj.GetName(), Path: dstPath, Size: int64(srcObj.GetSize()), IsFolder: srcObj.IsDir()}, + d.sh.FilesMv(ctx, srcObj.GetPath(), dstDir.GetPath()) } -func (d *IPFS) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { +func (d *IPFS) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { if d.Mode != "mfs" { - return fmt.Errorf("only write in mfs mode") + return nil, fmt.Errorf("only write in mfs mode") } - newFileName := dstDir.GetPath() + "/" + filepath.Base(srcObj.GetPath()) - return d.sh.FilesCp(ctx, srcObj.GetPath(), strings.ReplaceAll(newFileName, "\\", "/")) + dstPath := path.Join(path.Dir(srcObj.GetPath()), newName) + d.sh.FilesRm(ctx, dstPath, true) + return &model.Object{ID: srcObj.GetID(), Name: newName, Path: dstPath, Size: int64(srcObj.GetSize()), + IsFolder: srcObj.IsDir()}, d.sh.FilesMv(ctx, srcObj.GetPath(), dstPath) +} + +func (d *IPFS) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + if d.Mode != "mfs" { + return nil, fmt.Errorf("only write in mfs mode") + } + dstPath := path.Join(dstDir.GetPath(), path.Base(srcObj.GetPath())) + d.sh.FilesRm(ctx, dstPath, true) + return &model.Object{ID: srcObj.GetID(), Name: srcObj.GetName(), Path: dstPath, Size: int64(srcObj.GetSize()), IsFolder: srcObj.IsDir()}, + d.sh.FilesCp(ctx, path.Join("/ipfs/", srcObj.GetID()), dstPath, shell.FilesCp.Parents(true)) } func (d *IPFS) Remove(ctx context.Context, obj model.Obj) error { @@ -124,19 +160,25 @@ func (d *IPFS) Remove(ctx context.Context, obj model.Obj) error { return d.sh.FilesRm(ctx, obj.GetPath(), true) } -func (d *IPFS) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { +func (d *IPFS) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { if d.Mode != "mfs" { - return fmt.Errorf("only write in mfs mode") + return nil, fmt.Errorf("only write in mfs mode") } outHash, err := d.sh.Add(driver.NewLimitedUploadStream(ctx, &driver.ReaderUpdatingProgress{ Reader: s, UpdateProgress: up, })) if err != nil { - return err + return nil, err } - err = d.sh.FilesCp(ctx, "/ipfs/"+outHash, dstDir.GetPath()+"/"+strings.ReplaceAll(s.GetName(), "\\", "/")) - return err + dstPath := path.Join(dstDir.GetPath(), s.GetName()) + if s.GetExist() != nil { + d.sh.FilesRm(ctx, dstPath, true) + } + err = d.sh.FilesCp(ctx, path.Join("/ipfs/", outHash), dstPath, shell.FilesCp.Parents(true)) + gateurl := d.gateURL.JoinPath("/ipfs/", outHash) + gateurl.RawQuery = "filename=" + url.QueryEscape(s.GetName()) + return &model.Object{ID: outHash, Name: s.GetName(), Path: dstPath, Size: int64(s.GetSize()), IsFolder: s.IsDir()}, err } //func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { diff --git a/drivers/ipfs_api/meta.go b/drivers/ipfs_api/meta.go index c145644c..3837bec2 100644 --- a/drivers/ipfs_api/meta.go +++ b/drivers/ipfs_api/meta.go @@ -9,8 +9,8 @@ type Addition struct { // Usually one of two driver.RootPath Mode string `json:"mode" options:"ipfs,ipns,mfs" type:"select" required:"true"` - Endpoint string `json:"endpoint" default:"http://127.0.0.1:5001"` - Gateway string `json:"gateway" default:"http://127.0.0.1:8080"` + Endpoint string `json:"endpoint" default:"http://127.0.0.1:5001" required:"true"` + Gateway string `json:"gateway" default:"http://127.0.0.1:8080" required:"true"` } var config = driver.Config{ diff --git a/drivers/mopan/driver.go b/drivers/mopan/driver.go index 736d612a..f8f14300 100644 --- a/drivers/mopan/driver.go +++ b/drivers/mopan/driver.go @@ -269,9 +269,6 @@ func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre if err != nil { return nil, err } - defer func() { - _ = file.Close() - }() // step.1 uploadPartData, err := mopan.InitUploadPartData(ctx, mopan.UpdloadFileParam{ diff --git a/drivers/netease_music/util.go b/drivers/netease_music/util.go index 2e78be14..21718106 100644 --- a/drivers/netease_music/util.go +++ b/drivers/netease_music/util.go @@ -227,7 +227,6 @@ func (d *NeteaseMusic) putSongStream(ctx context.Context, stream model.FileStrea if err != nil { return err } - defer tmp.Close() u := uploader{driver: d, file: tmp} diff --git a/drivers/onedrive/util.go b/drivers/onedrive/util.go index 55434967..e256b7ae 100644 --- a/drivers/onedrive/util.go +++ b/drivers/onedrive/util.go @@ -220,7 +220,7 @@ func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.Fil if err != nil { return err } - req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) + req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) if err != nil { return err } diff --git a/drivers/onedrive_app/util.go b/drivers/onedrive_app/util.go index 1b01324e..5c3b6c92 100644 --- a/drivers/onedrive_app/util.go +++ b/drivers/onedrive_app/util.go @@ -170,7 +170,7 @@ func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model. if err != nil { return err } - req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(byteData))) + req, err := http.NewRequest("PUT", uploadUrl, driver.NewLimitedUploadStream(ctx, bytes.NewReader(byteData))) if err != nil { return err } diff --git a/drivers/pikpak/util.go b/drivers/pikpak/util.go index 61396aa4..4cb3fbc3 100644 --- a/drivers/pikpak/util.go +++ b/drivers/pikpak/util.go @@ -7,13 +7,6 @@ import ( "crypto/sha1" "encoding/hex" "fmt" - "github.com/alist-org/alist/v3/internal/driver" - "github.com/alist-org/alist/v3/internal/model" - "github.com/alist-org/alist/v3/internal/op" - "github.com/alist-org/alist/v3/pkg/utils" - "github.com/aliyun/aliyun-oss-go-sdk/oss" - jsoniter "github.com/json-iterator/go" - "github.com/pkg/errors" "io" "net/http" "path/filepath" @@ -24,7 +17,14 @@ import ( "time" "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aliyun/aliyun-oss-go-sdk/oss" "github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" + "github.com/pkg/errors" ) var AndroidAlgorithms = []string{ @@ -84,7 +84,7 @@ const ( WebClientID = "YUMx5nI8ZU8Ap8pm" WebClientSecret = "dbw2OtmVEeuUvIptb1Coyg" WebClientVersion = "2.0.0" - WebPackageName = "drive.mypikpak.com" + WebPackageName = "mypikpak.com" WebSdkVersion = "8.0.3" PCClientID = "YvtoWO6GNHiuCl7x" PCClientSecret = "1NIH5R1IEe2pAxZE3hv3uA" @@ -516,7 +516,7 @@ func (d *PikPak) UploadByMultipart(ctx context.Context, params *S3Params, fileSi continue } - b := driver.NewLimitedUploadStream(ctx, bytes.NewBuffer(buf)) + b := driver.NewLimitedUploadStream(ctx, bytes.NewReader(buf)) if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil { break } diff --git a/drivers/pikpak_share/util.go b/drivers/pikpak_share/util.go index 4111779f..2980069e 100644 --- a/drivers/pikpak_share/util.go +++ b/drivers/pikpak_share/util.go @@ -67,7 +67,7 @@ const ( WebClientID = "YUMx5nI8ZU8Ap8pm" WebClientSecret = "dbw2OtmVEeuUvIptb1Coyg" WebClientVersion = "2.0.0" - WebPackageName = "drive.mypikpak.com" + WebPackageName = "mypikpak.com" WebSdkVersion = "8.0.3" PCClientID = "YvtoWO6GNHiuCl7x" PCClientSecret = "1NIH5R1IEe2pAxZE3hv3uA" diff --git a/drivers/quark_uc/driver.go b/drivers/quark_uc/driver.go index 0f8884fa..7f497494 100644 --- a/drivers/quark_uc/driver.go +++ b/drivers/quark_uc/driver.go @@ -3,9 +3,8 @@ package quark import ( "bytes" "context" - "crypto/md5" - "crypto/sha1" "encoding/hex" + "hash" "io" "net/http" "time" @@ -14,6 +13,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + streamPkg "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" @@ -136,33 +136,33 @@ func (d *QuarkOrUC) Remove(ctx context.Context, obj model.Obj) error { } func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - tempFile, err := stream.CacheFullInTempFile() - if err != nil { - return err + md5Str, sha1Str := stream.GetHash().GetHash(utils.MD5), stream.GetHash().GetHash(utils.SHA1) + var ( + md5 hash.Hash + sha1 hash.Hash + ) + writers := []io.Writer{} + if len(md5Str) != utils.MD5.Width { + md5 = utils.MD5.NewFunc() + writers = append(writers, md5) } - defer func() { - _ = tempFile.Close() - }() - m := md5.New() - _, err = utils.CopyWithBuffer(m, tempFile) - if err != nil { - return err + if len(sha1Str) != utils.SHA1.Width { + sha1 = utils.SHA1.NewFunc() + writers = append(writers, sha1) } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err + + if len(writers) > 0 { + _, err := streamPkg.CacheFullInTempFileAndWriter(stream, io.MultiWriter(writers...)) + if err != nil { + return err + } + if md5 != nil { + md5Str = hex.EncodeToString(md5.Sum(nil)) + } + if sha1 != nil { + sha1Str = hex.EncodeToString(sha1.Sum(nil)) + } } - md5Str := hex.EncodeToString(m.Sum(nil)) - s := sha1.New() - _, err = utils.CopyWithBuffer(s, tempFile) - if err != nil { - return err - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err - } - sha1Str := hex.EncodeToString(s.Sum(nil)) // pre pre, err := d.upPre(stream, dstDir.GetID()) if err != nil { @@ -178,27 +178,28 @@ func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.File return nil } // part up - partSize := pre.Metadata.PartSize - var part []byte - md5s := make([]string, 0) - defaultBytes := make([]byte, partSize) total := stream.GetSize() left := total + partSize := int64(pre.Metadata.PartSize) + part := make([]byte, partSize) + count := int(total / partSize) + if total%partSize > 0 { + count++ + } + md5s := make([]string, 0, count) partNumber := 1 for left > 0 { if utils.IsCanceled(ctx) { return ctx.Err() } - if left > int64(partSize) { - part = defaultBytes - } else { - part = make([]byte, left) + if left < partSize { + part = part[:left] } - _, err := io.ReadFull(tempFile, part) + n, err := io.ReadFull(stream, part) if err != nil { return err } - left -= int64(len(part)) + left -= int64(n) log.Debugf("left: %d", left) reader := driver.NewLimitedUploadStream(ctx, bytes.NewReader(part)) m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, reader) diff --git a/drivers/thunder/driver.go b/drivers/thunder/driver.go index 7f41d003..1d2f2a81 100644 --- a/drivers/thunder/driver.go +++ b/drivers/thunder/driver.go @@ -12,6 +12,7 @@ import ( "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" "github.com/aws/aws-sdk-go/aws" @@ -44,26 +45,29 @@ func (x *Thunder) Init(ctx context.Context) (err error) { Common: &Common{ client: base.NewRestyClient(), Algorithms: []string{ - "HPxr4BVygTQVtQkIMwQH33ywbgYG5l4JoR", - "GzhNkZ8pOBsCY+7", - "v+l0ImTpG7c7/", - "e5ztohgVXNP", - "t", - "EbXUWyVVqQbQX39Mbjn2geok3/0WEkAVxeqhtx857++kjJiRheP8l77gO", - "o7dvYgbRMOpHXxCs", - "6MW8TD8DphmakaxCqVrfv7NReRRN7ck3KLnXBculD58MvxjFRqT+", - "kmo0HxCKVfmxoZswLB4bVA/dwqbVAYghSb", - "j", - "4scKJNdd7F27Hv7tbt", + "9uJNVj/wLmdwKrJaVj/omlQ", + "Oz64Lp0GigmChHMf/6TNfxx7O9PyopcczMsnf", + "Eb+L7Ce+Ej48u", + "jKY0", + "ASr0zCl6v8W4aidjPK5KHd1Lq3t+vBFf41dqv5+fnOd", + "wQlozdg6r1qxh0eRmt3QgNXOvSZO6q/GXK", + "gmirk+ciAvIgA/cxUUCema47jr/YToixTT+Q6O", + "5IiCoM9B1/788ntB", + "P07JH0h6qoM6TSUAK2aL9T5s2QBVeY9JWvalf", + "+oK0AN", }, - DeviceID: utils.GetMD5EncodeStr(x.Username + x.Password), + DeviceID: func() string { + if len(x.DeviceID) != 32 { + return utils.GetMD5EncodeStr(x.DeviceID) + } + return x.DeviceID + }(), ClientID: "Xp6vsxz_7IYVw2BB", ClientSecret: "Xp6vsy4tN9toTVdMSpomVdXpRmES", - ClientVersion: "7.51.0.8196", + ClientVersion: "8.31.0.9726", PackageName: "com.xunlei.downloadprovider", - UserAgent: "ANDROID-com.xunlei.downloadprovider/7.51.0.8196 netWorkType/5G appid/40 deviceName/Xiaomi_M2004j7ac deviceModel/M2004J7AC OSVersion/12 protocolVersion/301 platformVersion/10 sdkVersion/220200 Oauth2Client/0.9 (Linux 4_14_186-perf-gddfs8vbb238b) (JAVA 0)", + UserAgent: "ANDROID-com.xunlei.downloadprovider/8.31.0.9726 netWorkType/5G appid/40 deviceName/Xiaomi_M2004j7ac deviceModel/M2004J7AC OSVersion/12 protocolVersion/301 platformVersion/10 sdkVersion/512000 Oauth2Client/0.9 (Linux 4_14_186-perf-gddfs8vbb238b) (JAVA 0)", DownloadUserAgent: "Dalvik/2.1.0 (Linux; U; Android 12; M2004J7AC Build/SP1A.210812.016)", - refreshCTokenCk: func(token string) { x.CaptchaToken = token op.MustSaveDriverStorage(x) @@ -79,6 +83,8 @@ func (x *Thunder) Init(ctx context.Context) (err error) { x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) op.MustSaveDriverStorage(x) } + // 清空 信任密钥 + x.Addition.CreditKey = "" } x.SetTokenResp(token) return err @@ -92,6 +98,17 @@ func (x *Thunder) Init(ctx context.Context) (err error) { x.SetCaptchaToken(ctoekn) } + if x.Addition.CreditKey != "" { + x.SetCreditKey(x.Addition.CreditKey) + } + + if x.Addition.DeviceID != "" { + x.Common.DeviceID = x.Addition.DeviceID + } else { + x.Addition.DeviceID = x.Common.DeviceID + op.MustSaveDriverStorage(x) + } + // 防止重复登录 identity := x.GetIdentity() if x.identity != identity || !x.IsLogin() { @@ -101,6 +118,8 @@ func (x *Thunder) Init(ctx context.Context) (err error) { if err != nil { return err } + // 清空 信任密钥 + x.Addition.CreditKey = "" x.SetTokenResp(token) } return nil @@ -160,6 +179,17 @@ func (x *ThunderExpert) Init(ctx context.Context) (err error) { x.SetCaptchaToken(x.CaptchaToken) } + if x.ExpertAddition.CreditKey != "" { + x.SetCreditKey(x.ExpertAddition.CreditKey) + } + + if x.ExpertAddition.DeviceID != "" { + x.Common.DeviceID = x.ExpertAddition.DeviceID + } else { + x.ExpertAddition.DeviceID = x.Common.DeviceID + op.MustSaveDriverStorage(x) + } + // 签名方法 if x.SignType == "captcha_sign" { x.Common.Timestamp = x.Timestamp @@ -193,6 +223,8 @@ func (x *ThunderExpert) Init(ctx context.Context) (err error) { if err != nil { return err } + // 清空 信任密钥 + x.ExpertAddition.CreditKey = "" x.SetTokenResp(token) x.SetRefreshTokenFunc(func() error { token, err := x.XunLeiCommon.RefreshToken(x.TokenResp.RefreshToken) @@ -201,6 +233,8 @@ func (x *ThunderExpert) Init(ctx context.Context) (err error) { if err != nil { x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) } + // 清空 信任密钥 + x.ExpertAddition.CreditKey = "" } x.SetTokenResp(token) op.MustSaveDriverStorage(x) @@ -232,7 +266,8 @@ func (x *ThunderExpert) SetTokenResp(token *TokenResp) { type XunLeiCommon struct { *Common - *TokenResp // 登录信息 + *TokenResp // 登录信息 + *CoreLoginResp // core登录信息 refreshTokenFunc func() error } @@ -333,22 +368,17 @@ func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error { } func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { - hi := file.GetHash() - gcid := hi.GetHash(hash_extend.GCID) + gcid := file.GetHash().GetHash(hash_extend.GCID) + var err error if len(gcid) < hash_extend.GCID.Width { - tFile, err := file.CacheFullInTempFile() - if err != nil { - return err - } - - gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize()) + _, gcid, err = stream.CacheFullInTempFileAndHash(file, hash_extend.GCID, file.GetSize()) if err != nil { return err } } var resp UploadTaskResponse - _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + _, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { r.SetContext(ctx) r.SetBody(&base.Json{ "kind": FILE, @@ -437,6 +467,10 @@ func (xc *XunLeiCommon) SetTokenResp(tr *TokenResp) { xc.TokenResp = tr } +func (xc *XunLeiCommon) SetCoreTokenResp(tr *CoreLoginResp) { + xc.CoreLoginResp = tr +} + // 携带Authorization和CaptchaToken的请求 func (xc *XunLeiCommon) Request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { data, err := xc.Common.Request(url, method, func(req *resty.Request) { @@ -465,7 +499,7 @@ func (xc *XunLeiCommon) Request(url string, method string, callback base.ReqCall } return nil, err case 9: // 验证码token过期 - if err = xc.RefreshCaptchaTokenAtLogin(GetAction(method, url), xc.UserID); err != nil { + if err = xc.RefreshCaptchaTokenAtLogin(GetAction(method, url), xc.TokenResp.UserID); err != nil { return nil, err } default: @@ -497,20 +531,25 @@ func (xc *XunLeiCommon) RefreshToken(refreshToken string) (*TokenResp, error) { // 登录 func (xc *XunLeiCommon) Login(username, password string) (*TokenResp, error) { - url := XLUSER_API_URL + "/auth/signin" - err := xc.RefreshCaptchaTokenInLogin(GetAction(http.MethodPost, url), username) + //v3 login拿到 sessionID + sessionID, err := xc.CoreLogin(username, password) if err != nil { return nil, err } + //v1 login拿到令牌 + url := XLUSER_API_URL + "/auth/signin/token" + if err = xc.RefreshCaptchaTokenInLogin(GetAction(http.MethodPost, url), username); err != nil { + return nil, err + } var resp TokenResp _, err = xc.Common.Request(url, http.MethodPost, func(req *resty.Request) { + req.SetPathParam("client_id", xc.ClientID) req.SetBody(&SignInRequest{ - CaptchaToken: xc.GetCaptchaToken(), ClientID: xc.ClientID, ClientSecret: xc.ClientSecret, - Username: username, - Password: password, + Provider: SignProvider, + SigninToken: sessionID, }) }, &resp) if err != nil { @@ -586,3 +625,48 @@ func (xc *XunLeiCommon) DeleteOfflineTasks(ctx context.Context, taskIDs []string } return nil } + +func (xc *XunLeiCommon) CoreLogin(username string, password string) (sessionID string, err error) { + url := XLUSER_API_BASE_URL + "/xluser.core.login/v3/login" + var resp CoreLoginResp + res, err := xc.Common.Request(url, http.MethodPost, func(req *resty.Request) { + req.SetHeader("User-Agent", "android-ok-http-client/xl-acc-sdk/version-5.0.12.512000") + req.SetBody(&CoreLoginRequest{ + ProtocolVersion: "301", + SequenceNo: "1000012", + PlatformVersion: "10", + IsCompressed: "0", + Appid: APPID, + ClientVersion: "8.31.0.9726", + PeerID: "00000000000000000000000000000000", + AppName: "ANDROID-com.xunlei.downloadprovider", + SdkVersion: "512000", + Devicesign: generateDeviceSign(xc.DeviceID, xc.PackageName), + NetWorkType: "WIFI", + ProviderName: "NONE", + DeviceModel: "M2004J7AC", + DeviceName: "Xiaomi_M2004j7ac", + OSVersion: "12", + Creditkey: xc.GetCreditKey(), + Hl: "zh-CN", + UserName: username, + PassWord: password, + VerifyKey: "", + VerifyCode: "", + IsMd5Pwd: "0", + }) + }, nil) + if err != nil { + return "", err + } + + if err = utils.Json.Unmarshal(res, &resp); err != nil { + return "", err + } + + xc.SetCoreTokenResp(&resp) + + sessionID = resp.SessionID + + return sessionID, nil +} diff --git a/drivers/thunder/meta.go b/drivers/thunder/meta.go index 12b01cba..5e6e2513 100644 --- a/drivers/thunder/meta.go +++ b/drivers/thunder/meta.go @@ -23,23 +23,25 @@ type ExpertAddition struct { RefreshToken string `json:"refresh_token" required:"true" help:"login type is refresh_token,this is required"` // 签名方法1 - Algorithms string `json:"algorithms" required:"true" help:"sign type is algorithms,this is required" default:"HPxr4BVygTQVtQkIMwQH33ywbgYG5l4JoR,GzhNkZ8pOBsCY+7,v+l0ImTpG7c7/,e5ztohgVXNP,t,EbXUWyVVqQbQX39Mbjn2geok3/0WEkAVxeqhtx857++kjJiRheP8l77gO,o7dvYgbRMOpHXxCs,6MW8TD8DphmakaxCqVrfv7NReRRN7ck3KLnXBculD58MvxjFRqT+,kmo0HxCKVfmxoZswLB4bVA/dwqbVAYghSb,j,4scKJNdd7F27Hv7tbt"` + Algorithms string `json:"algorithms" required:"true" help:"sign type is algorithms,this is required" default:"9uJNVj/wLmdwKrJaVj/omlQ,Oz64Lp0GigmChHMf/6TNfxx7O9PyopcczMsnf,Eb+L7Ce+Ej48u,jKY0,ASr0zCl6v8W4aidjPK5KHd1Lq3t+vBFf41dqv5+fnOd,wQlozdg6r1qxh0eRmt3QgNXOvSZO6q/GXK,gmirk+ciAvIgA/cxUUCema47jr/YToixTT+Q6O,5IiCoM9B1/788ntB,P07JH0h6qoM6TSUAK2aL9T5s2QBVeY9JWvalf,+oK0AN"` // 签名方法2 CaptchaSign string `json:"captcha_sign" required:"true" help:"sign type is captcha_sign,this is required"` Timestamp string `json:"timestamp" required:"true" help:"sign type is captcha_sign,this is required"` // 验证码 CaptchaToken string `json:"captcha_token"` + // 信任密钥 + CreditKey string `json:"credit_key" help:"credit key,used for login"` // 必要且影响登录,由签名决定 - DeviceID string `json:"device_id" required:"true" default:"9aa5c268e7bcfc197a9ad88e2fb330e5"` + DeviceID string `json:"device_id" default:""` ClientID string `json:"client_id" required:"true" default:"Xp6vsxz_7IYVw2BB"` ClientSecret string `json:"client_secret" required:"true" default:"Xp6vsy4tN9toTVdMSpomVdXpRmES"` - ClientVersion string `json:"client_version" required:"true" default:"7.51.0.8196"` + ClientVersion string `json:"client_version" required:"true" default:"8.31.0.9726"` PackageName string `json:"package_name" required:"true" default:"com.xunlei.downloadprovider"` //不影响登录,影响下载速度 - UserAgent string `json:"user_agent" required:"true" default:"ANDROID-com.xunlei.downloadprovider/7.51.0.8196 netWorkType/4G appid/40 deviceName/Xiaomi_M2004j7ac deviceModel/M2004J7AC OSVersion/12 protocolVersion/301 platformVersion/10 sdkVersion/220200 Oauth2Client/0.9 (Linux 4_14_186-perf-gdcf98eab238b) (JAVA 0)"` + UserAgent string `json:"user_agent" required:"true" default:"ANDROID-com.xunlei.downloadprovider/8.31.0.9726 netWorkType/5G appid/40 deviceName/Xiaomi_M2004j7ac deviceModel/M2004J7AC OSVersion/12 protocolVersion/301 platformVersion/10 sdkVersion/512000 Oauth2Client/0.9 (Linux 4_14_186-perf-gddfs8vbb238b) (JAVA 0)"` DownloadUserAgent string `json:"download_user_agent" required:"true" default:"Dalvik/2.1.0 (Linux; U; Android 12; M2004J7AC Build/SP1A.210812.016)"` //优先使用视频链接代替下载链接 @@ -74,6 +76,10 @@ type Addition struct { Username string `json:"username" required:"true"` Password string `json:"password" required:"true"` CaptchaToken string `json:"captcha_token"` + // 信任密钥 + CreditKey string `json:"credit_key" help:"credit key,used for login"` + // 登录设备ID + DeviceID string `json:"device_id" default:""` } // 登录特征,用于判断是否重新登录 diff --git a/drivers/thunder/types.go b/drivers/thunder/types.go index b7355b2a..1fe8432c 100644 --- a/drivers/thunder/types.go +++ b/drivers/thunder/types.go @@ -18,6 +18,10 @@ type ErrResp struct { } func (e *ErrResp) IsError() bool { + if e.ErrorMsg == "success" { + return false + } + return e.ErrorCode != 0 || e.ErrorMsg != "" || e.ErrorDescription != "" } @@ -61,13 +65,79 @@ func (t *TokenResp) Token() string { } type SignInRequest struct { - CaptchaToken string `json:"captcha_token"` - ClientID string `json:"client_id"` ClientSecret string `json:"client_secret"` - Username string `json:"username"` - Password string `json:"password"` + Provider string `json:"provider"` + SigninToken string `json:"signin_token"` +} + +type CoreLoginRequest struct { + ProtocolVersion string `json:"protocolVersion"` + SequenceNo string `json:"sequenceNo"` + PlatformVersion string `json:"platformVersion"` + IsCompressed string `json:"isCompressed"` + Appid string `json:"appid"` + ClientVersion string `json:"clientVersion"` + PeerID string `json:"peerID"` + AppName string `json:"appName"` + SdkVersion string `json:"sdkVersion"` + Devicesign string `json:"devicesign"` + NetWorkType string `json:"netWorkType"` + ProviderName string `json:"providerName"` + DeviceModel string `json:"deviceModel"` + DeviceName string `json:"deviceName"` + OSVersion string `json:"OSVersion"` + Creditkey string `json:"creditkey"` + Hl string `json:"hl"` + UserName string `json:"userName"` + PassWord string `json:"passWord"` + VerifyKey string `json:"verifyKey"` + VerifyCode string `json:"verifyCode"` + IsMd5Pwd string `json:"isMd5Pwd"` +} + +type CoreLoginResp struct { + Account string `json:"account"` + Creditkey string `json:"creditkey"` + /* Error string `json:"error"` + ErrorCode string `json:"errorCode"` + ErrorDescription string `json:"error_description"`*/ + ExpiresIn int `json:"expires_in"` + IsCompressed string `json:"isCompressed"` + IsSetPassWord string `json:"isSetPassWord"` + KeepAliveMinPeriod string `json:"keepAliveMinPeriod"` + KeepAlivePeriod string `json:"keepAlivePeriod"` + LoginKey string `json:"loginKey"` + NickName string `json:"nickName"` + PlatformVersion string `json:"platformVersion"` + ProtocolVersion string `json:"protocolVersion"` + SecureKey string `json:"secureKey"` + SequenceNo string `json:"sequenceNo"` + SessionID string `json:"sessionID"` + Timestamp string `json:"timestamp"` + UserID string `json:"userID"` + UserName string `json:"userName"` + UserNewNo string `json:"userNewNo"` + Version string `json:"version"` + /* VipList []struct { + ExpireDate string `json:"expireDate"` + IsAutoDeduct string `json:"isAutoDeduct"` + IsVip string `json:"isVip"` + IsYear string `json:"isYear"` + PayID string `json:"payId"` + PayName string `json:"payName"` + Register string `json:"register"` + Vasid string `json:"vasid"` + VasType string `json:"vasType"` + VipDayGrow string `json:"vipDayGrow"` + VipGrow string `json:"vipGrow"` + VipLevel string `json:"vipLevel"` + Icon struct { + General string `json:"general"` + Small string `json:"small"` + } `json:"icon"` + } `json:"vipList"`*/ } /* @@ -251,3 +321,29 @@ type Params struct { PredictSpeed string `json:"predict_speed"` PredictType string `json:"predict_type"` } + +// LoginReviewResp 登录验证响应 +type LoginReviewResp struct { + Creditkey string `json:"creditkey"` + Error string `json:"error"` + ErrorCode string `json:"errorCode"` + ErrorDesc string `json:"errorDesc"` + ErrorDescURL string `json:"errorDescUrl"` + ErrorIsRetry int `json:"errorIsRetry"` + ErrorDescription string `json:"error_description"` + IsCompressed string `json:"isCompressed"` + PlatformVersion string `json:"platformVersion"` + ProtocolVersion string `json:"protocolVersion"` + Reviewurl string `json:"reviewurl"` + SequenceNo string `json:"sequenceNo"` + UserID string `json:"userID"` + VerifyType string `json:"verifyType"` +} + +// ReviewData 验证数据 +type ReviewData struct { + Creditkey string `json:"creditkey"` + Reviewurl string `json:"reviewurl"` + Deviceid string `json:"deviceid"` + Devicesign string `json:"devicesign"` +} diff --git a/drivers/thunder/util.go b/drivers/thunder/util.go index f509e6b2..b7afe56d 100644 --- a/drivers/thunder/util.go +++ b/drivers/thunder/util.go @@ -1,8 +1,10 @@ package thunder import ( + "crypto/md5" "crypto/sha1" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" @@ -15,10 +17,11 @@ import ( ) const ( - API_URL = "https://api-pan.xunlei.com/drive/v1" - FILE_API_URL = API_URL + "/files" - TASK_API_URL = API_URL + "/tasks" - XLUSER_API_URL = "https://xluser-ssl.xunlei.com/v1" + API_URL = "https://api-pan.xunlei.com/drive/v1" + FILE_API_URL = API_URL + "/files" + TASK_API_URL = API_URL + "/tasks" + XLUSER_API_BASE_URL = "https://xluser-ssl.xunlei.com" + XLUSER_API_URL = XLUSER_API_BASE_URL + "/v1" ) const ( @@ -34,6 +37,12 @@ const ( UPLOAD_TYPE_URL = "UPLOAD_TYPE_URL" ) +const ( + SignProvider = "access_end_point_token" + APPID = "40" + APPKey = "34a062aaa22f906fca4fefe9fb3a3021" +) + func GetAction(method string, url string) string { urlpath := regexp.MustCompile(`://[^/]+((/[^/\s?#]+)*)`).FindStringSubmatch(url)[1] return method + ":" + urlpath @@ -44,6 +53,8 @@ type Common struct { captchaToken string + creditKey string + // 签名相关,二选一 Algorithms []string Timestamp, CaptchaSign string @@ -69,6 +80,13 @@ func (c *Common) GetCaptchaToken() string { return c.captchaToken } +func (c *Common) SetCreditKey(creditKey string) { + c.creditKey = creditKey +} +func (c *Common) GetCreditKey() string { + return c.creditKey +} + // 刷新验证码token(登录后) func (c *Common) RefreshCaptchaTokenAtLogin(action, userID string) error { metas := map[string]string{ @@ -170,12 +188,53 @@ func (c *Common) Request(url, method string, callback base.ReqCallback, resp int var erron ErrResp utils.Json.Unmarshal(res.Body(), &erron) if erron.IsError() { + // review_panel 表示需要短信验证码进行验证 + if erron.ErrorMsg == "review_panel" { + return nil, c.getReviewData(res) + } + return nil, &erron } return res.Body(), nil } +// 获取验证所需内容 +func (c *Common) getReviewData(res *resty.Response) error { + var reviewResp LoginReviewResp + var reviewData ReviewData + + if err := utils.Json.Unmarshal(res.Body(), &reviewResp); err != nil { + return err + } + + deviceSign := generateDeviceSign(c.DeviceID, c.PackageName) + + reviewData = ReviewData{ + Creditkey: reviewResp.Creditkey, + Reviewurl: reviewResp.Reviewurl + "&deviceid=" + deviceSign, + Deviceid: deviceSign, + Devicesign: deviceSign, + } + + // 将reviewData转为JSON字符串 + reviewDataJSON, _ := json.MarshalIndent(reviewData, "", " ") + //reviewDataJSON, _ := json.Marshal(reviewData) + + return fmt.Errorf(` +
+ 🔒 本次登录需要验证
+ This login requires verification + +

下面是验证所需要的数据,具体使用方法请参照对应的驱动文档
+ Below are the relevant verification data. For specific usage methods, please refer to the corresponding driver documentation.

+
+
%s
+
+
`, string(reviewDataJSON)) +} + // 计算文件Gcid func getGcid(r io.Reader, size int64) (string, error) { calcBlockSize := func(j int64) int64 { @@ -201,3 +260,24 @@ func getGcid(r io.Reader, size int64) (string, error) { } return hex.EncodeToString(hash1.Sum(nil)), nil } + +func generateDeviceSign(deviceID, packageName string) string { + + signatureBase := fmt.Sprintf("%s%s%s%s", deviceID, packageName, APPID, APPKey) + + sha1Hash := sha1.New() + sha1Hash.Write([]byte(signatureBase)) + sha1Result := sha1Hash.Sum(nil) + + sha1String := hex.EncodeToString(sha1Result) + + md5Hash := md5.New() + md5Hash.Write([]byte(sha1String)) + md5Result := md5Hash.Sum(nil) + + md5String := hex.EncodeToString(md5Result) + + deviceSign := fmt.Sprintf("div101.%s%s", deviceID, md5String) + + return deviceSign +} diff --git a/drivers/thunder_browser/driver.go b/drivers/thunder_browser/driver.go index 7ce71f7d..0b38d077 100644 --- a/drivers/thunder_browser/driver.go +++ b/drivers/thunder_browser/driver.go @@ -4,10 +4,15 @@ import ( "context" "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" + streamPkg "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" "github.com/aws/aws-sdk-go/aws" @@ -15,9 +20,6 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/go-resty/resty/v2" - "io" - "net/http" - "strings" ) type ThunderBrowser struct { @@ -456,15 +458,10 @@ func (xc *XunLeiBrowserCommon) Remove(ctx context.Context, obj model.Obj) error } func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - hi := stream.GetHash() - gcid := hi.GetHash(hash_extend.GCID) + gcid := stream.GetHash().GetHash(hash_extend.GCID) + var err error if len(gcid) < hash_extend.GCID.Width { - tFile, err := stream.CacheFullInTempFile() - if err != nil { - return err - } - - gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + _, gcid, err = streamPkg.CacheFullInTempFileAndHash(stream, hash_extend.GCID, stream.GetSize()) if err != nil { return err } @@ -481,7 +478,7 @@ func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream } var resp UploadTaskResponse - _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + _, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { r.SetContext(ctx) r.SetBody(&js) }, &resp) diff --git a/drivers/thunderx/driver.go b/drivers/thunderx/driver.go index 2194bdc6..6ee8901a 100644 --- a/drivers/thunderx/driver.go +++ b/drivers/thunderx/driver.go @@ -3,11 +3,15 @@ package thunderx import ( "context" "fmt" + "net/http" + "strings" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" "github.com/aws/aws-sdk-go/aws" @@ -15,8 +19,6 @@ import ( "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/go-resty/resty/v2" - "net/http" - "strings" ) type ThunderX struct { @@ -364,22 +366,17 @@ func (xc *XunLeiXCommon) Remove(ctx context.Context, obj model.Obj) error { } func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { - hi := file.GetHash() - gcid := hi.GetHash(hash_extend.GCID) + gcid := file.GetHash().GetHash(hash_extend.GCID) + var err error if len(gcid) < hash_extend.GCID.Width { - tFile, err := file.CacheFullInTempFile() - if err != nil { - return err - } - - gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize()) + _, gcid, err = stream.CacheFullInTempFileAndHash(file, hash_extend.GCID, file.GetSize()) if err != nil { return err } } var resp UploadTaskResponse - _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + _, err = xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { r.SetContext(ctx) r.SetBody(&base.Json{ "kind": FILE, diff --git a/drivers/url_tree/driver.go b/drivers/url_tree/driver.go index f97d5cc5..049bd2db 100644 --- a/drivers/url_tree/driver.go +++ b/drivers/url_tree/driver.go @@ -243,7 +243,25 @@ func (d *Urls) PutURL(ctx context.Context, dstDir model.Obj, name, url string) ( } func (d *Urls) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - return errs.UploadNotSupported + if !d.Writable { + return errs.PermissionDenied + } + d.mutex.Lock() + defer d.mutex.Unlock() + node := GetNodeFromRootByPath(d.root, dstDir.GetPath()) // parent + if node == nil { + return errs.ObjectNotFound + } + if node.isFile() { + return errs.NotFolder + } + file, err := parseFileLine(stream.GetName(), d.HeadSize) + if err != nil { + return err + } + node.Children = append(node.Children, file) + d.updateStorage() + return nil } func (d *Urls) updateStorage() { diff --git a/go.mod b/go.mod index 97a477d3..e8afe0e7 100644 --- a/go.mod +++ b/go.mod @@ -68,7 +68,7 @@ require ( golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e golang.org/x/image v0.19.0 - golang.org/x/net v0.37.0 + golang.org/x/net v0.38.0 golang.org/x/oauth2 v0.22.0 golang.org/x/time v0.8.0 google.golang.org/appengine v1.6.8 diff --git a/go.sum b/go.sum index 86fb779e..6fbaeb2b 100644 --- a/go.sum +++ b/go.sum @@ -741,6 +741,8 @@ golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= diff --git a/internal/archive/archives/utils.go b/internal/archive/archives/utils.go index fdae1009..2f499a10 100644 --- a/internal/archive/archives/utils.go +++ b/internal/archive/archives/utils.go @@ -10,6 +10,7 @@ import ( "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/mholt/archives" ) @@ -73,7 +74,7 @@ func decompress(fsys fs2.FS, filePath, targetPath string, up model.UpdateProgres return err } defer f.Close() - _, err = io.Copy(f, &stream.ReaderUpdatingProgress{ + _, err = utils.CopyWithBuffer(f, &stream.ReaderUpdatingProgress{ Reader: &stream.SimpleReaderWithSize{ Reader: rc, Size: stat.Size(), diff --git a/internal/archive/iso9660/utils.go b/internal/archive/iso9660/utils.go index 12de8e6e..0e4cfb1c 100644 --- a/internal/archive/iso9660/utils.go +++ b/internal/archive/iso9660/utils.go @@ -1,14 +1,15 @@ package iso9660 import ( - "github.com/alist-org/alist/v3/internal/errs" - "github.com/alist-org/alist/v3/internal/model" - "github.com/alist-org/alist/v3/internal/stream" - "github.com/kdomanski/iso9660" - "io" "os" stdpath "path" "strings" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/kdomanski/iso9660" ) func getImage(ss *stream.SeekableStream) (*iso9660.Image, error) { @@ -66,7 +67,7 @@ func decompress(f *iso9660.File, path string, up model.UpdateProgress) error { return err } defer file.Close() - _, err = io.Copy(file, &stream.ReaderUpdatingProgress{ + _, err = utils.CopyWithBuffer(file, &stream.ReaderUpdatingProgress{ Reader: &stream.SimpleReaderWithSize{ Reader: f.Reader(), Size: f.Size(), diff --git a/internal/conf/config.go b/internal/conf/config.go index 1766ae84..cdb86fee 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -35,6 +35,7 @@ type Scheme struct { KeyFile string `json:"key_file" env:"KEY_FILE"` UnixFile string `json:"unix_file" env:"UNIX_FILE"` UnixFilePerm string `json:"unix_file_perm" env:"UNIX_FILE_PERM"` + EnableH2c bool `json:"enable_h2c" env:"ENABLE_H2C"` } type LogConfig struct { diff --git a/internal/fs/archive.go b/internal/fs/archive.go index b056decf..dbae9b33 100644 --- a/internal/fs/archive.go +++ b/internal/fs/archive.go @@ -90,9 +90,11 @@ func (t *ArchiveDownloadTask) RunWithoutPushUploadTask() (*ArchiveContentUploadT t.SetTotalBytes(total) t.status = "getting src object" for _, s := range ss { - _, err = s.CacheFullInTempFileAndUpdateProgress(func(p float64) { - t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total)) - }) + if s.GetFile() == nil { + _, err = stream.CacheFullInTempFileAndUpdateProgress(s, func(p float64) { + t.SetProgress((float64(cur) + float64(s.GetSize())*p/100.0) / float64(total)) + }) + } cur += s.GetSize() if err != nil { return nil, err diff --git a/internal/model/obj.go b/internal/model/obj.go index 552b1241..f0fce7a1 100644 --- a/internal/model/obj.go +++ b/internal/model/obj.go @@ -2,6 +2,7 @@ package model import ( "io" + "os" "sort" "strings" "time" @@ -48,7 +49,8 @@ type FileStreamer interface { RangeRead(http_range.Range) (io.Reader, error) //for a non-seekable Stream, if Read is called, this function won't work CacheFullInTempFile() (File, error) - CacheFullInTempFileAndUpdateProgress(up UpdateProgress) (File, error) + SetTmpFile(r *os.File) + GetFile() File } type UpdateProgress func(percentage float64) diff --git a/internal/net/request.go b/internal/net/request.go index d4f9321c..a1ff6d20 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -248,8 +248,9 @@ func (d *downloader) sendChunkTask(newConcurrency bool) error { size: finalSize, id: d.nextChunk, buf: buf, + + newConcurrency: newConcurrency, } - ch.newConcurrency = newConcurrency d.pos += finalSize d.nextChunk++ d.chunkChannel <- ch diff --git a/internal/net/serve.go b/internal/net/serve.go index 63e1cb45..bdeac0ac 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -52,19 +52,19 @@ import ( // // If the caller has set w's ETag header formatted per RFC 7232, section 2.3, // ServeHTTP uses it to handle requests using If-Match, If-None-Match, or If-Range. -func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReadCloser model.RangeReadCloserIF) { +func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReadCloser model.RangeReadCloserIF) error { defer RangeReadCloser.Close() setLastModified(w, modTime) done, rangeReq := checkPreconditions(w, r, modTime) if done { - return + return nil } if size < 0 { // since too many functions need file size to work, // will not implement the support of unknown file size here http.Error(w, "negative content size not supported", http.StatusInternalServerError) - return + return nil } code := http.StatusOK @@ -103,7 +103,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time fallthrough default: http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) - return + return nil } if sumRangesSize(ranges) > size { @@ -124,7 +124,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time code = http.StatusTooManyRequests } http.Error(w, err.Error(), code) - return + return nil } sendContent = reader case len(ranges) == 1: @@ -147,7 +147,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time code = http.StatusTooManyRequests } http.Error(w, err.Error(), code) - return + return nil } sendSize = ra.Length code = http.StatusPartialContent @@ -205,9 +205,11 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time if err == ErrExceedMaxConcurrency { code = http.StatusTooManyRequests } - http.Error(w, err.Error(), code) + w.WriteHeader(code) + return err } } + return nil } func ProcessHeader(origin, override http.Header) http.Header { result := http.Header{} diff --git a/internal/op/fs.go b/internal/op/fs.go index 01727e75..64e99335 100644 --- a/internal/op/fs.go +++ b/internal/op/fs.go @@ -3,12 +3,14 @@ package op import ( "context" stdpath "path" + "slices" "time" "github.com/Xhofe/go-cache" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/generic_sync" "github.com/alist-org/alist/v3/pkg/singleflight" "github.com/alist-org/alist/v3/pkg/utils" @@ -25,6 +27,12 @@ func updateCacheObj(storage driver.Driver, path string, oldObj model.Obj, newObj key := Key(storage, path) objs, ok := listCache.Get(key) if ok { + for i, obj := range objs { + if obj.GetName() == newObj.GetName() { + objs = slices.Delete(objs, i, i+1) + break + } + } for i, obj := range objs { if obj.GetName() == oldObj.GetName() { objs[i] = newObj @@ -510,6 +518,12 @@ func Put(ctx context.Context, storage driver.Driver, dstDirPath string, file mod log.Errorf("failed to close file streamer, %v", err) } }() + // UrlTree PUT + if storage.GetStorage().Driver == "UrlTree" { + var link string + dstDirPath, link = urlTreeSplitLineFormPath(stdpath.Join(dstDirPath, file.GetName())) + file = &stream.FileStream{Obj: &model.Object{Name: link}} + } // if file exist and size = 0, delete it dstDirPath = utils.FixAndCleanPath(dstDirPath) dstPath := stdpath.Join(dstDirPath, file.GetName()) diff --git a/internal/op/path.go b/internal/op/path.go index 27f7e183..912a0000 100644 --- a/internal/op/path.go +++ b/internal/op/path.go @@ -2,6 +2,7 @@ package op import ( "github.com/alist-org/alist/v3/internal/errs" + stdpath "path" "strings" "github.com/alist-org/alist/v3/internal/driver" @@ -27,3 +28,30 @@ func GetStorageAndActualPath(rawPath string) (storage driver.Driver, actualPath actualPath = utils.FixAndCleanPath(strings.TrimPrefix(rawPath, mountPath)) return } + +// urlTreeSplitLineFormPath 分割path中分割真实路径和UrlTree定义字符串 +func urlTreeSplitLineFormPath(path string) (pp string, file string) { + // url.PathUnescape 会移除 // ,手动加回去 + path = strings.Replace(path, "https:/", "https://", 1) + path = strings.Replace(path, "http:/", "http://", 1) + if strings.Contains(path, ":https:/") || strings.Contains(path, ":http:/") { + // URL-Tree模式 /url_tree_drivr/file_name[:size[:time]]:https://example.com/file + fPath := strings.SplitN(path, ":", 2)[0] + pp, _ = stdpath.Split(fPath) + file = path[len(pp):] + } else if strings.Contains(path, "/https:/") || strings.Contains(path, "/http:/") { + // URL-Tree模式 /url_tree_drivr/https://example.com/file + index := strings.Index(path, "/http://") + if index == -1 { + index = strings.Index(path, "/https://") + } + pp = path[:index] + file = path[index+1:] + } else { + pp, file = stdpath.Split(path) + } + if pp == "" { + pp = "/" + } + return +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go index f6b045a0..64160915 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -94,27 +94,17 @@ func (f *FileStream) CacheFullInTempFile() (model.File, error) { f.Add(tmpF) f.tmpFile = tmpF f.Reader = tmpF - return f.tmpFile, nil + return tmpF, nil } -func (f *FileStream) CacheFullInTempFileAndUpdateProgress(up model.UpdateProgress) (model.File, error) { +func (f *FileStream) GetFile() model.File { if f.tmpFile != nil { - return f.tmpFile, nil + return f.tmpFile } if file, ok := f.Reader.(model.File); ok { - return file, nil + return file } - tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{ - Reader: f, - UpdateProgress: up, - }, f.GetSize()) - if err != nil { - return nil, err - } - f.Add(tmpF) - f.tmpFile = tmpF - f.Reader = tmpF - return f.tmpFile, nil + return nil } const InMemoryBufMaxSize = 10 // Megabytes @@ -127,31 +117,36 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { // 参考 internal/net/request.go httpRange.Length = f.GetSize() - httpRange.Start } - if f.peekBuff != nil && httpRange.Start < int64(f.peekBuff.Len()) && httpRange.Start+httpRange.Length-1 < int64(f.peekBuff.Len()) { + size := httpRange.Start + httpRange.Length + if f.peekBuff != nil && size <= int64(f.peekBuff.Len()) { return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil } - if f.tmpFile == nil { - if httpRange.Start == 0 && httpRange.Length <= InMemoryBufMaxSizeBytes && f.peekBuff == nil { - bufSize := utils.Min(httpRange.Length, f.GetSize()) - newBuf := bytes.NewBuffer(make([]byte, 0, bufSize)) - n, err := utils.CopyWithBufferN(newBuf, f.Reader, bufSize) + var cache io.ReaderAt = f.GetFile() + if cache == nil { + if size <= InMemoryBufMaxSizeBytes { + bufSize := min(size, f.GetSize()) + // 使用bytes.Buffer作为io.CopyBuffer的写入对象,CopyBuffer会调用Buffer.ReadFrom + // 即使被写入的数据量与Buffer.Cap一致,Buffer也会扩大 + buf := make([]byte, bufSize) + n, err := io.ReadFull(f.Reader, buf) if err != nil { return nil, err } - if n != bufSize { + if n != int(bufSize) { return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n) } - f.peekBuff = bytes.NewReader(newBuf.Bytes()) + f.peekBuff = bytes.NewReader(buf) f.Reader = io.MultiReader(f.peekBuff, f.Reader) - return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil + cache = f.peekBuff } else { - _, err := f.CacheFullInTempFile() + var err error + cache, err = f.CacheFullInTempFile() if err != nil { return nil, err } } } - return io.NewSectionReader(f.tmpFile, httpRange.Start, httpRange.Length), nil + return io.NewSectionReader(cache, httpRange.Start, httpRange.Length), nil } var _ model.FileStreamer = (*SeekableStream)(nil) @@ -176,13 +171,13 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) if len(fs.Mimetype) == 0 { fs.Mimetype = utils.GetMimeType(fs.Obj.GetName()) } - ss := SeekableStream{FileStream: fs, Link: link} + ss := &SeekableStream{FileStream: fs, Link: link} if ss.Reader != nil { result, ok := ss.Reader.(model.File) if ok { ss.mFile = result ss.Closers.Add(result) - return &ss, nil + return ss, nil } } if ss.Link != nil { @@ -198,7 +193,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) ss.mFile = mFile ss.Reader = mFile ss.Closers.Add(mFile) - return &ss, nil + return ss, nil } if ss.Link.RangeReadCloser != nil { ss.rangeReadCloser = &RateLimitRangeReadCloser{ @@ -206,7 +201,7 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) Limiter: ServerDownloadLimit, } ss.Add(ss.rangeReadCloser) - return &ss, nil + return ss, nil } if len(ss.Link.URL) > 0 { rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link) @@ -219,10 +214,12 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) } ss.rangeReadCloser = rrc ss.Add(rrc) - return &ss, nil + return ss, nil } } - + if fs.Reader != nil { + return ss, nil + } return nil, fmt.Errorf("illegal seekableStream") } @@ -248,7 +245,7 @@ func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, erro } return rc, nil } - return nil, fmt.Errorf("can't find mFile or rangeReadCloser") + return ss.FileStream.RangeRead(httpRange) } //func (f *FileStream) GetReader() io.Reader { @@ -278,7 +275,7 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) { if ss.tmpFile != nil { return ss.tmpFile, nil } - if _, ok := ss.mFile.(*os.File); ok { + if ss.mFile != nil { return ss.mFile, nil } tmpF, err := utils.CreateTempFile(ss, ss.GetSize()) @@ -288,27 +285,17 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) { ss.Add(tmpF) ss.tmpFile = tmpF ss.Reader = tmpF - return ss.tmpFile, nil + return tmpF, nil } -func (ss *SeekableStream) CacheFullInTempFileAndUpdateProgress(up model.UpdateProgress) (model.File, error) { +func (ss *SeekableStream) GetFile() model.File { if ss.tmpFile != nil { - return ss.tmpFile, nil + return ss.tmpFile } - if _, ok := ss.mFile.(*os.File); ok { - return ss.mFile, nil + if ss.mFile != nil { + return ss.mFile } - tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{ - Reader: ss, - UpdateProgress: up, - }, ss.GetSize()) - if err != nil { - return nil, err - } - ss.Add(tmpF) - ss.tmpFile = tmpF - ss.Reader = tmpF - return ss.tmpFile, nil + return nil } func (f *FileStream) SetTmpFile(r *os.File) { diff --git a/internal/stream/util.go b/internal/stream/util.go index 01019482..5b935a90 100644 --- a/internal/stream/util.go +++ b/internal/stream/util.go @@ -2,6 +2,7 @@ package stream import ( "context" + "encoding/hex" "fmt" "io" "net/http" @@ -96,3 +97,45 @@ func (r *ReaderWithCtx) Close() error { } return nil } + +func CacheFullInTempFileAndUpdateProgress(stream model.FileStreamer, up model.UpdateProgress) (model.File, error) { + if cache := stream.GetFile(); cache != nil { + up(100) + return cache, nil + } + tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{ + Reader: stream, + UpdateProgress: up, + }, stream.GetSize()) + if err == nil { + stream.SetTmpFile(tmpF) + } + return tmpF, err +} + +func CacheFullInTempFileAndWriter(stream model.FileStreamer, w io.Writer) (model.File, error) { + if cache := stream.GetFile(); cache != nil { + _, err := cache.Seek(0, io.SeekStart) + if err == nil { + _, err = utils.CopyWithBuffer(w, cache) + if err == nil { + _, err = cache.Seek(0, io.SeekStart) + } + } + return cache, err + } + tmpF, err := utils.CreateTempFile(io.TeeReader(stream, w), stream.GetSize()) + if err == nil { + stream.SetTmpFile(tmpF) + } + return tmpF, err +} + +func CacheFullInTempFileAndHash(stream model.FileStreamer, hashType *utils.HashType, params ...any) (model.File, string, error) { + h := hashType.NewFunc(params...) + tmpF, err := CacheFullInTempFileAndWriter(stream, h) + if err != nil { + return nil, "", err + } + return tmpF, hex.EncodeToString(h.Sum(nil)), err +} diff --git a/server/common/proxy.go b/server/common/proxy.go index f9e1e4bb..ca7f6325 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -39,11 +39,10 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. return nil } else if link.RangeReadCloser != nil { attachHeader(w, file) - net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ + return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ RangeReadCloserIF: link.RangeReadCloser, Limiter: stream.ServerDownloadLimit, }) - return nil } else if link.Concurrency != 0 || link.PartSize != 0 { attachHeader(w, file) size := file.GetSize() @@ -66,11 +65,10 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. rc, err := down.Download(ctx, req) return rc, err } - net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ + return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ RangeReadCloserIF: &model.RangeReadCloser{RangeReader: rangeReader}, Limiter: stream.ServerDownloadLimit, }) - return nil } else { //transparent proxy header := net.ProcessHeader(r.Header, link.Header) @@ -90,10 +88,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. Limiter: stream.ServerDownloadLimit, Ctx: r.Context(), }) - if err != nil { - return err - } - return nil + return err } } func attachHeader(w http.ResponseWriter, file model.Obj) { @@ -133,3 +128,29 @@ func ProxyRange(link *model.Link, size int64) { link.RangeReadCloser = nil } } + +type InterceptResponseWriter struct { + http.ResponseWriter + io.Writer +} + +func (iw *InterceptResponseWriter) Write(p []byte) (int, error) { + return iw.Writer.Write(p) +} + +type WrittenResponseWriter struct { + http.ResponseWriter + written bool +} + +func (ww *WrittenResponseWriter) Write(p []byte) (int, error) { + n, err := ww.ResponseWriter.Write(p) + if !ww.written && n > 0 { + ww.written = true + } + return n, err +} + +func (ww *WrittenResponseWriter) IsWritten() bool { + return ww.written +} diff --git a/server/handles/down.go b/server/handles/down.go index 1153881f..2c5c2faf 100644 --- a/server/handles/down.go +++ b/server/handles/down.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "net/http" stdpath "path" "strconv" "strings" @@ -129,15 +128,16 @@ func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange boo if proxyRange { common.ProxyRange(link, file.GetSize()) } + Writer := &common.WrittenResponseWriter{ResponseWriter: c.Writer} //优先处理md文件 if utils.Ext(file.GetName()) == "md" && setting.GetBool(conf.FilterReadMeScripts) { - w := c.Writer buf := bytes.NewBuffer(make([]byte, 0, file.GetSize())) - err = common.Proxy(&proxyResponseWriter{ResponseWriter: w, Writer: buf}, c.Request, link, file) + w := &common.InterceptResponseWriter{ResponseWriter: Writer, Writer: buf} + err = common.Proxy(w, c.Request, link, file) if err == nil && buf.Len() > 0 { - if w.Status() < 200 || w.Status() > 300 { - w.Write(buf.Bytes()) + if c.Writer.Status() < 200 || c.Writer.Status() > 300 { + c.Writer.Write(buf.Bytes()) return } @@ -148,19 +148,23 @@ func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange boo buf.Reset() err = bluemonday.UGCPolicy().SanitizeReaderToWriter(&html, buf) if err == nil { - w.Header().Set("Content-Length", strconv.FormatInt(int64(buf.Len()), 10)) - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, err = utils.CopyWithBuffer(c.Writer, buf) + Writer.Header().Set("Content-Length", strconv.FormatInt(int64(buf.Len()), 10)) + Writer.Header().Set("Content-Type", "text/html; charset=utf-8") + _, err = utils.CopyWithBuffer(Writer, buf) } } } } else { - err = common.Proxy(c.Writer, c.Request, link, file) + err = common.Proxy(Writer, c.Request, link, file) } - if err != nil { - common.ErrorResp(c, err, 500, true) + if err == nil { return } + if Writer.IsWritten() { + log.Errorf("%s %s local proxy error: %+v", c.Request.Method, c.Request.URL.Path, err) + } else { + common.ErrorResp(c, err, 500, true) + } } // TODO need optimize @@ -182,12 +186,3 @@ func canProxy(storage driver.Driver, filename string) bool { } return false } - -type proxyResponseWriter struct { - http.ResponseWriter - io.Writer -} - -func (pw *proxyResponseWriter) Write(p []byte) (int, error) { - return pw.Writer.Write(p) -} diff --git a/server/handles/fsup.go b/server/handles/fsup.go index 15a6328b..41344fb8 100644 --- a/server/handles/fsup.go +++ b/server/handles/fsup.go @@ -1,8 +1,6 @@ package handles import ( - "github.com/alist-org/alist/v3/internal/task" - "github.com/alist-org/alist/v3/pkg/utils" "io" "net/url" stdpath "path" @@ -12,6 +10,8 @@ import ( "github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/internal/task" + "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" ) @@ -44,7 +44,7 @@ func FsStream(c *gin.Context) { } if !overwrite { if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil { - _, _ = io.Copy(io.Discard, c.Request.Body) + _, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body) common.ErrorStrResp(c, "file exists", 403) return } @@ -66,6 +66,10 @@ func FsStream(c *gin.Context) { if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" { h[utils.SHA256] = sha256 } + mimetype := c.GetHeader("Content-Type") + if len(mimetype) == 0 { + mimetype = utils.GetMimeType(name) + } s := &stream.FileStream{ Obj: &model.Object{ Name: name, @@ -74,7 +78,7 @@ func FsStream(c *gin.Context) { HashInfo: utils.NewHashInfoByMap(h), }, Reader: c.Request.Body, - Mimetype: c.GetHeader("Content-Type"), + Mimetype: mimetype, WebPutAsTask: asTask, } var t task.TaskExtensionInfo @@ -89,6 +93,9 @@ func FsStream(c *gin.Context) { return } if t == nil { + if n, _ := io.ReadFull(c.Request.Body, []byte{0}); n == 1 { + _, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body) + } common.SuccessResp(c) return } @@ -114,7 +121,7 @@ func FsForm(c *gin.Context) { } if !overwrite { if res, _ := fs.Get(c, path, &fs.GetArgs{NoLog: true}); res != nil { - _, _ = io.Copy(io.Discard, c.Request.Body) + _, _ = utils.CopyWithBuffer(io.Discard, c.Request.Body) common.ErrorStrResp(c, "file exists", 403) return } @@ -150,6 +157,10 @@ func FsForm(c *gin.Context) { if sha256 := c.GetHeader("X-File-Sha256"); sha256 != "" { h[utils.SHA256] = sha256 } + mimetype := file.Header.Get("Content-Type") + if len(mimetype) == 0 { + mimetype = utils.GetMimeType(name) + } s := stream.FileStream{ Obj: &model.Object{ Name: name, @@ -158,7 +169,7 @@ func FsForm(c *gin.Context) { HashInfo: utils.NewHashInfoByMap(h), }, Reader: f, - Mimetype: file.Header.Get("Content-Type"), + Mimetype: mimetype, WebPutAsTask: asTask, } var t task.TaskExtensionInfo @@ -168,12 +179,7 @@ func FsForm(c *gin.Context) { }{f} t, err = fs.PutAsTask(c, dir, &s) } else { - ss, err := stream.NewSeekableStream(s, nil) - if err != nil { - common.ErrorResp(c, err, 500) - return - } - err = fs.PutDirectly(c, dir, ss, true) + err = fs.PutDirectly(c, dir, &s, true) } if err != nil { common.ErrorResp(c, err, 500) diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index 1b7ec6ff..f22e15aa 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -24,7 +24,6 @@ import ( "github.com/alist-org/alist/v3/internal/sign" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" - log "github.com/sirupsen/logrus" ) type Handler struct { @@ -59,7 +58,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { status, err = h.handleOptions(brw, r) case "GET", "HEAD", "POST": useBufferedWriter = false - status, err = h.handleGetHeadPost(w, r) + Writer := &common.WrittenResponseWriter{ResponseWriter: w} + status, err = h.handleGetHeadPost(Writer, r) + if status != 0 && Writer.IsWritten() { + status = 0 + } case "DELETE": status, err = h.handleDelete(brw, r) case "PUT": @@ -247,8 +250,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta } err = common.Proxy(w, r, link, fi) if err != nil { - log.Errorf("webdav proxy error: %+v", err) - return http.StatusInternalServerError, err + return http.StatusInternalServerError, fmt.Errorf("webdav proxy error: %+v", err) } } else if storage.GetStorage().WebdavProxy() && downProxyUrl != "" { u := fmt.Sprintf("%s%s?sign=%s",