diff --git a/drivers/crypt/util.go b/drivers/crypt/util.go index 908af2aa..f4246756 100644 --- a/drivers/crypt/util.go +++ b/drivers/crypt/util.go @@ -13,7 +13,7 @@ import ( ) func RequestRangedHttp(r *http.Request, link *model.Link, offset, length int64) (*http.Response, error) { - header := net.ProcessHeader(&http.Header{}, &link.Header) + header := net.ProcessHeader(http.Header{}, link.Header) header = http_range.ApplyRangeToHttpHeader(http_range.Range{Start: offset, Length: length}, header) return net.RequestHttp("GET", header, link.URL) diff --git a/internal/fs/copy.go b/internal/fs/copy.go index 87735f2a..b8f92599 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -94,7 +94,7 @@ func copyFileBetween2Storages(tsk *task.Task[uint64], srcStorage, dstStorage dri if err != nil { return errors.WithMessagef(err, "failed get [%s] link", srcFilePath) } - stream, err := getFileStreamFromLink(srcFile, link) + stream, err := getFileStreamFromLink(tsk.Ctx, srcFile, link) if err != nil { return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath) } diff --git a/internal/fs/util.go b/internal/fs/util.go index 10b9c473..5eca5fce 100644 --- a/internal/fs/util.go +++ b/internal/fs/util.go @@ -1,18 +1,21 @@ package fs import ( - "github.com/alist-org/alist/v3/pkg/http_range" + "context" "io" "net/http" "strings" + "github.com/alist-org/alist/v3/internal/net" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/pkg/errors" ) -func getFileStreamFromLink(file model.Obj, link *model.Link) (*model.FileStream, error) { +func getFileStreamFromLink(ctx context.Context, file model.Obj, link *model.Link) (*model.FileStream, error) { var rc io.ReadCloser var err error mimetype := utils.GetMimeType(file.GetName()) @@ -23,6 +26,21 @@ func getFileStreamFromLink(file model.Obj, link *model.Link) (*model.FileStream, } } else if link.ReadSeekCloser != nil { rc = link.ReadSeekCloser + } else if link.Concurrency != 0 || link.PartSize != 0 { + down := net.NewDownloader(func(d *net.Downloader) { + d.Concurrency = link.Concurrency + d.PartSize = link.PartSize + }) + req := &net.HttpRequestParams{ + URL: link.URL, + Range: http_range.Range{Length: -1}, + Size: file.GetSize(), + HeaderRef: link.Header, + } + rc, err = down.Download(ctx, req) + if err != nil { + return nil, err + } } else { //TODO: add accelerator req, err := http.NewRequest(http.MethodGet, link.URL, nil) diff --git a/internal/net/request.go b/internal/net/request.go index bb5b68bb..087e44b2 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -3,9 +3,6 @@ package net import ( "context" "fmt" - "github.com/alist-org/alist/v3/pkg/http_range" - "github.com/aws/aws-sdk-go/aws/awsutil" - log "github.com/sirupsen/logrus" "io" "math" "net/http" @@ -13,6 +10,10 @@ import ( "strings" "sync" "time" + + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/aws/aws-sdk-go/aws/awsutil" + log "github.com/sirupsen/logrus" ) // DefaultDownloadPartSize is the default range of bytes to get at a time when @@ -60,7 +61,7 @@ func NewDownloader(options ...func(*Downloader)) *Downloader { // cache some data, then return Reader with assembled data // Supports range, do not support unknown FileSize, and will fail if FileSize is incorrect // memory usage is at about Concurrency*PartSize, use this wisely -func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readCloser *io.ReadCloser, err error) { +func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readCloser io.ReadCloser, err error) { var finalP HttpRequestParams awsutil.Copy(&finalP, p) @@ -107,7 +108,7 @@ type downloader struct { } // download performs the implementation of the object download across ranged GETs. -func (d *downloader) download() (*io.ReadCloser, error) { +func (d *downloader) download() (io.ReadCloser, error) { d.ctx, d.cancel = context.WithCancel(d.ctx) pos := d.params.Range.Start @@ -133,7 +134,7 @@ func (d *downloader) download() (*io.ReadCloser, error) { if err != nil { return nil, err } - return &resp.Body, nil + return resp.Body, nil } // workers @@ -152,7 +153,7 @@ func (d *downloader) download() (*io.ReadCloser, error) { var rc io.ReadCloser = NewMultiReadCloser(d.chunks[0].buf, d.interrupt, d.finishBuf) // Return error - return &rc, d.err + return rc, d.err } func (d *downloader) sendChunkTask() *chunk { ch := &d.chunks[d.nextChunk] @@ -384,7 +385,7 @@ type HttpRequestParams struct { URL string //only want data within this range Range http_range.Range - HeaderRef *http.Header + HeaderRef http.Header //total file size Size int64 } diff --git a/internal/net/serve.go b/internal/net/serve.go index 83349368..b2da536a 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -2,13 +2,6 @@ package net import ( "fmt" - "github.com/alist-org/alist/v3/drivers/base" - "github.com/alist-org/alist/v3/internal/conf" - "github.com/alist-org/alist/v3/internal/model" - "github.com/alist-org/alist/v3/pkg/http_range" - "github.com/alist-org/alist/v3/pkg/utils" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" "io" "mime" "mime/multipart" @@ -18,6 +11,14 @@ import ( "strings" "sync" "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" ) //this file is inspired by GO_SDK net.http.ServeContent @@ -109,7 +110,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time } switch { case len(ranges) == 0: - reader, err := RangeReaderFunc(http_range.Range{0, -1}) + reader, err := RangeReaderFunc(http_range.Range{Length: -1}) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -191,29 +192,29 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time } //defer sendContent.Close() } -func ProcessHeader(origin, override *http.Header) *http.Header { +func ProcessHeader(origin, override http.Header) http.Header { result := http.Header{} // client header - for h, val := range *origin { + for h, val := range origin { if utils.SliceContains(conf.SlicesMap[conf.ProxyIgnoreHeaders], strings.ToLower(h)) { continue } result[h] = val } // needed header - for h, val := range *override { + for h, val := range override { result[h] = val } - return &result + return result } // RequestHttp deal with Header properly then send the request -func RequestHttp(httpMethod string, headerOverride *http.Header, URL string) (*http.Response, error) { +func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*http.Response, error) { req, err := http.NewRequest(httpMethod, URL, nil) if err != nil { return nil, err } - req.Header = *headerOverride + req.Header = headerOverride log.Debugln("request Header: ", req.Header) log.Debugln("request URL: ", URL) res, err := HttpClient().Do(req) diff --git a/pkg/http_range/range.go b/pkg/http_range/range.go index 4a7d2703..0d6598f2 100644 --- a/pkg/http_range/range.go +++ b/pkg/http_range/range.go @@ -120,10 +120,10 @@ func (r Range) contentRange(size int64) string { } // ApplyRangeToHttpHeader for http request header -func ApplyRangeToHttpHeader(p Range, headerRef *http.Header) *http.Header { +func ApplyRangeToHttpHeader(p Range, headerRef http.Header) http.Header { header := headerRef if header == nil { - header = &http.Header{} + header = http.Header{} } if p.Start == 0 && p.Length < 0 { header.Del("Range") diff --git a/server/common/proxy.go b/server/common/proxy.go index f6148860..45c2b820 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -3,16 +3,17 @@ package common import ( "context" "fmt" + "io" + "net/http" + "net/url" + "sync" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/net" "github.com/alist-org/alist/v3/pkg/http_range" "github.com/alist-org/alist/v3/pkg/utils" "github.com/pkg/errors" - "io" - "net/http" - "net/url" - "sync" ) func HttpClient() *http.Client { @@ -52,7 +53,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. size := file.GetSize() //var finalClosers model.Closers finalClosers := utils.NewClosers() - header := net.ProcessHeader(&r.Header, &link.Header) + header := net.ProcessHeader(r.Header, link.Header) rangeReader := func(httpRange http_range.Range) (io.ReadCloser, error) { down := net.NewDownloader(func(d *net.Downloader) { d.Concurrency = link.Concurrency @@ -65,15 +66,15 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. HeaderRef: header, } rc, err := down.Download(context.Background(), req) - finalClosers.Add(*rc) - return *rc, err + finalClosers.Add(rc) + return rc, err } net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rangeReader) defer finalClosers.Close() return nil } else { //transparent proxy - header := net.ProcessHeader(&r.Header, &link.Header) + header := net.ProcessHeader(r.Header, link.Header) res, err := net.RequestHttp(r.Method, header, link.URL) if err != nil { return err