diff --git a/internal/net/request.go b/internal/net/request.go index 087e44b2..0bcd966d 100644 --- a/internal/net/request.go +++ b/internal/net/request.go @@ -1,6 +1,7 @@ package net import ( + "bytes" "context" "fmt" "io" @@ -202,7 +203,6 @@ func (d *downloader) downloadPart() { //defer d.wg.Done() for { c, ok := <-d.chunkChannel - log.Debugf("downloadPart tried to get chunk") if !ok { break } @@ -211,7 +211,7 @@ func (d *downloader) downloadPart() { // of download producer. continue } - + log.Debugf("downloadPart tried to get chunk") if err := d.downloadChunk(&c); err != nil { d.setErr(err) } @@ -220,7 +220,7 @@ func (d *downloader) downloadPart() { // downloadChunk downloads the chunk func (d *downloader) downloadChunk(ch *chunk) error { - log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.buf.buffer.id) + log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.id) var n int64 var err error params := d.getParamsFromChunk(ch) @@ -262,6 +262,7 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int if err != nil { return 0, err } + defer resp.Body.Close() //only check file size on the first task if ch.id == 0 { err = d.checkTotalBytes(resp) @@ -279,7 +280,6 @@ func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n) return n, &errReadingBody{err: err} } - defer resp.Body.Close() return n, nil } @@ -402,13 +402,8 @@ func (e *errReadingBody) Unwrap() error { } type MultiReadCloser struct { - io.ReadCloser - - //total int //total bufArr - //wPos int //current reader wPos cfg *cfg closer closerFunc - //getBuf getBufFunc finish finishBufFUnc } @@ -449,99 +444,26 @@ func (mr MultiReadCloser) Close() error { return mr.closer() } -type Buffer struct { - data []byte - wPos int //writer position - id int - rPos int //reader position - lock sync.Mutex - - once bool //combined use with notify & lock, to get notify once - notify chan int // notifies new writes -} - -func (buf *Buffer) Write(p []byte) (n int, err error) { - inSize := len(p) - if inSize == 0 { - return 0, nil - } - - if inSize > len(buf.data)-buf.wPos { - return 0, fmt.Errorf("exceeding buffer max size,inSize=%d ,buf.data.len=%d , buf.wPos=%d", - inSize, len(buf.data), buf.wPos) - } - copy(buf.data[buf.wPos:], p) - buf.wPos += inSize - - //give read a notice if once==true - buf.lock.Lock() - if buf.once == true { - buf.notify <- inSize //struct{}{} - } - buf.once = false - buf.lock.Unlock() - - return inSize, nil -} - -func (buf *Buffer) getPos() (n int) { - return buf.wPos -} -func (buf *Buffer) reset() { - buf.wPos = 0 - buf.rPos = 0 -} - -// waitTillNewWrite notify caller that new write happens -func (buf *Buffer) waitTillNewWrite(pos int) error { - //log.Debugf("waitTillNewWrite, current wPos=%d", pos) - var err error - - //defer buffer.lock.Unlock() - if pos >= len(buf.data) { - err = fmt.Errorf("there will not be any new write") - } else if pos > buf.wPos { - err = fmt.Errorf("illegal read position") - } else if pos == buf.wPos { - buf.lock.Lock() - buf.once = true - //buffer.wg1.Add(1) - buf.lock.Unlock() - //wait for write - log.Debugf("waitTillNewWrite wait for notify") - writes := <-buf.notify - log.Debugf("waitTillNewWrite got new write from notify, last writes:%+v", writes) - //if pos >= buf.wPos { - // //wrote 0 bytes - // return fmt.Errorf("write has error") - //} - return nil - } - //only case: wPos < buffer.wPos - return err -} - type Buf struct { - buffer *Buffer // Buffer we read from - size int //expected size + buffer *bytes.Buffer + size int //expected size ctx context.Context + off int + rw sync.RWMutex + notify chan struct{} } // NewBuf is a buffer that can have 1 read & 1 write at the same time. // when read is faster write, immediately feed data to read after written func NewBuf(ctx context.Context, maxSize int, id int) *Buf { d := make([]byte, maxSize) - buffer := &Buffer{data: d, id: id, notify: make(chan int)} - buffer.reset() - return &Buf{ctx: ctx, buffer: buffer, size: maxSize} + return &Buf{ctx: ctx, buffer: bytes.NewBuffer(d), size: maxSize, notify: make(chan struct{})} } func (br *Buf) Reset(size int) { - br.buffer.reset() + br.buffer.Reset() br.size = size -} -func (br *Buf) GetId() int { - return br.buffer.id + br.off = 0 } func (br *Buf) Read(p []byte) (n int, err error) { @@ -551,48 +473,49 @@ func (br *Buf) Read(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } - if br.buffer.rPos == br.size { + if br.off >= br.size { return 0, io.EOF } - //persist buffer position as another thread is keep increasing it - bufPos := br.buffer.getPos() - outSize := bufPos - br.buffer.rPos - - if outSize == 0 { - //var wg sync.WaitGroup - err := br.waitTillNewWrite(br.buffer.rPos) - if err != nil { - return 0, err - } - bufPos = br.buffer.getPos() - outSize = bufPos - br.buffer.rPos + br.rw.RLock() + n, err = br.buffer.Read(p) + br.rw.RUnlock() + if err == nil { + br.off += n + return n, err } - - if len(p) < outSize { - // p is not big enough - outSize = len(p) + if err != io.EOF { + return n, err } - copy(p, br.buffer.data[br.buffer.rPos:br.buffer.rPos+outSize]) - br.buffer.rPos += outSize - if br.buffer.rPos == br.size { - err = io.EOF + if n != 0 { + br.off += n + return n, nil + } + // n==0, err==io.EOF + // wait for new write for 200ms + select { + case <-br.ctx.Done(): + return 0, br.ctx.Err() + case <-br.notify: + return 0, nil + case <-time.After(time.Millisecond * 200): + return 0, nil } - - return outSize, err -} - -// waitTillNewWrite is expensive, since we just checked that no new data, wait 0.2s -func (br *Buf) waitTillNewWrite(pos int) error { - time.Sleep(200 * time.Millisecond) - return br.buffer.waitTillNewWrite(br.buffer.rPos) } func (br *Buf) Write(p []byte) (n int, err error) { if err := br.ctx.Err(); err != nil { return 0, err } - return br.buffer.Write(p) + br.rw.Lock() + defer br.rw.Unlock() + n, err = br.buffer.Write(p) + select { + case br.notify <- struct{}{}: + default: + } + return } + func (br *Buf) Close() { - close(br.buffer.notify) + close(br.notify) } diff --git a/internal/net/request_test.go b/internal/net/request_test.go index 39bfd82a..edc38c1c 100644 --- a/internal/net/request_test.go +++ b/internal/net/request_test.go @@ -7,14 +7,15 @@ import ( "bytes" "context" "fmt" - "github.com/alist-org/alist/v3/pkg/http_range" - "github.com/sirupsen/logrus" - "golang.org/x/exp/slices" "io" "io/ioutil" "net/http" "sync" "testing" + + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" ) var buf22MB = make([]byte, 1024*1024*22) @@ -55,7 +56,7 @@ func TestDownloadOrder(t *testing.T) { if err != nil { t.Fatalf("expect no error, got %v", err) } - resultBuf, err := io.ReadAll(*readCloser) + resultBuf, err := io.ReadAll(readCloser) if err != nil { t.Fatalf("expect no error, got %v", err) } @@ -111,7 +112,7 @@ func TestDownloadSingle(t *testing.T) { if err != nil { t.Fatalf("expect no error, got %v", err) } - resultBuf, err := io.ReadAll(*readCloser) + resultBuf, err := io.ReadAll(readCloser) if err != nil { t.Fatalf("expect no error, got %v", err) } diff --git a/internal/net/serve.go b/internal/net/serve.go index b2da536a..688882b9 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -215,14 +215,10 @@ func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*ht return nil, err } req.Header = headerOverride - log.Debugln("request Header: ", req.Header) - log.Debugln("request URL: ", URL) res, err := HttpClient().Do(req) if err != nil { return nil, err } - log.Debugf("response status: %d", res.StatusCode) - log.Debugln("response Header: ", res.Header) // TODO clean header with blocklist or passlist res.Header.Del("set-cookie") if res.StatusCode >= 400 { @@ -231,7 +227,6 @@ func RequestHttp(httpMethod string, headerOverride http.Header, URL string) (*ht log.Debugln(msg) return res, errors.New(msg) } - return res, nil } diff --git a/pkg/http_range/range.go b/pkg/http_range/range.go index 0d6598f2..6a5451a1 100644 --- a/pkg/http_range/range.go +++ b/pkg/http_range/range.go @@ -109,16 +109,11 @@ func ParseRange(s string, size int64) ([]Range, error) { // nolint:gocognit func (r Range) MimeHeader(contentType string, size int64) textproto.MIMEHeader { return textproto.MIMEHeader{ - "Content-Range": {r.contentRange(size)}, + "Content-Range": {r.ContentRange(size)}, "Content-Type": {contentType}, } } -// for http response header -func (r Range) contentRange(size int64) string { - return fmt.Sprintf("bytes %d-%d/%d", r.Start, r.Start+r.Length-1, size) -} - // ApplyRangeToHttpHeader for http request header func ApplyRangeToHttpHeader(p Range, headerRef http.Header) http.Header { header := headerRef