diff --git a/internal/net/serve.go b/internal/net/serve.go index 63e1cb45..bdeac0ac 100644 --- a/internal/net/serve.go +++ b/internal/net/serve.go @@ -52,19 +52,19 @@ import ( // // If the caller has set w's ETag header formatted per RFC 7232, section 2.3, // ServeHTTP uses it to handle requests using If-Match, If-None-Match, or If-Range. -func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReadCloser model.RangeReadCloserIF) { +func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReadCloser model.RangeReadCloserIF) error { defer RangeReadCloser.Close() setLastModified(w, modTime) done, rangeReq := checkPreconditions(w, r, modTime) if done { - return + return nil } if size < 0 { // since too many functions need file size to work, // will not implement the support of unknown file size here http.Error(w, "negative content size not supported", http.StatusInternalServerError) - return + return nil } code := http.StatusOK @@ -103,7 +103,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time fallthrough default: http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable) - return + return nil } if sumRangesSize(ranges) > size { @@ -124,7 +124,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time code = http.StatusTooManyRequests } http.Error(w, err.Error(), code) - return + return nil } sendContent = reader case len(ranges) == 1: @@ -147,7 +147,7 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time code = http.StatusTooManyRequests } http.Error(w, err.Error(), code) - return + return nil } sendSize = ra.Length code = http.StatusPartialContent @@ -205,9 +205,11 @@ func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time if err == ErrExceedMaxConcurrency { code = http.StatusTooManyRequests } - http.Error(w, err.Error(), code) + w.WriteHeader(code) + return err } } + return nil } func ProcessHeader(origin, override http.Header) http.Header { result := http.Header{} diff --git a/server/common/proxy.go b/server/common/proxy.go index f9e1e4bb..ca7f6325 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -39,11 +39,10 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. return nil } else if link.RangeReadCloser != nil { attachHeader(w, file) - net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ + return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ RangeReadCloserIF: link.RangeReadCloser, Limiter: stream.ServerDownloadLimit, }) - return nil } else if link.Concurrency != 0 || link.PartSize != 0 { attachHeader(w, file) size := file.GetSize() @@ -66,11 +65,10 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. rc, err := down.Download(ctx, req) return rc, err } - net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ + return net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), &stream.RateLimitRangeReadCloser{ RangeReadCloserIF: &model.RangeReadCloser{RangeReader: rangeReader}, Limiter: stream.ServerDownloadLimit, }) - return nil } else { //transparent proxy header := net.ProcessHeader(r.Header, link.Header) @@ -90,10 +88,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. Limiter: stream.ServerDownloadLimit, Ctx: r.Context(), }) - if err != nil { - return err - } - return nil + return err } } func attachHeader(w http.ResponseWriter, file model.Obj) { @@ -133,3 +128,29 @@ func ProxyRange(link *model.Link, size int64) { link.RangeReadCloser = nil } } + +type InterceptResponseWriter struct { + http.ResponseWriter + io.Writer +} + +func (iw *InterceptResponseWriter) Write(p []byte) (int, error) { + return iw.Writer.Write(p) +} + +type WrittenResponseWriter struct { + http.ResponseWriter + written bool +} + +func (ww *WrittenResponseWriter) Write(p []byte) (int, error) { + n, err := ww.ResponseWriter.Write(p) + if !ww.written && n > 0 { + ww.written = true + } + return n, err +} + +func (ww *WrittenResponseWriter) IsWritten() bool { + return ww.written +} diff --git a/server/handles/down.go b/server/handles/down.go index 1153881f..2c5c2faf 100644 --- a/server/handles/down.go +++ b/server/handles/down.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "net/http" stdpath "path" "strconv" "strings" @@ -129,15 +128,16 @@ func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange boo if proxyRange { common.ProxyRange(link, file.GetSize()) } + Writer := &common.WrittenResponseWriter{ResponseWriter: c.Writer} //优先处理md文件 if utils.Ext(file.GetName()) == "md" && setting.GetBool(conf.FilterReadMeScripts) { - w := c.Writer buf := bytes.NewBuffer(make([]byte, 0, file.GetSize())) - err = common.Proxy(&proxyResponseWriter{ResponseWriter: w, Writer: buf}, c.Request, link, file) + w := &common.InterceptResponseWriter{ResponseWriter: Writer, Writer: buf} + err = common.Proxy(w, c.Request, link, file) if err == nil && buf.Len() > 0 { - if w.Status() < 200 || w.Status() > 300 { - w.Write(buf.Bytes()) + if c.Writer.Status() < 200 || c.Writer.Status() > 300 { + c.Writer.Write(buf.Bytes()) return } @@ -148,19 +148,23 @@ func localProxy(c *gin.Context, link *model.Link, file model.Obj, proxyRange boo buf.Reset() err = bluemonday.UGCPolicy().SanitizeReaderToWriter(&html, buf) if err == nil { - w.Header().Set("Content-Length", strconv.FormatInt(int64(buf.Len()), 10)) - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, err = utils.CopyWithBuffer(c.Writer, buf) + Writer.Header().Set("Content-Length", strconv.FormatInt(int64(buf.Len()), 10)) + Writer.Header().Set("Content-Type", "text/html; charset=utf-8") + _, err = utils.CopyWithBuffer(Writer, buf) } } } } else { - err = common.Proxy(c.Writer, c.Request, link, file) + err = common.Proxy(Writer, c.Request, link, file) } - if err != nil { - common.ErrorResp(c, err, 500, true) + if err == nil { return } + if Writer.IsWritten() { + log.Errorf("%s %s local proxy error: %+v", c.Request.Method, c.Request.URL.Path, err) + } else { + common.ErrorResp(c, err, 500, true) + } } // TODO need optimize @@ -182,12 +186,3 @@ func canProxy(storage driver.Driver, filename string) bool { } return false } - -type proxyResponseWriter struct { - http.ResponseWriter - io.Writer -} - -func (pw *proxyResponseWriter) Write(p []byte) (int, error) { - return pw.Writer.Write(p) -} diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index 1b7ec6ff..f22e15aa 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -24,7 +24,6 @@ import ( "github.com/alist-org/alist/v3/internal/sign" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" - log "github.com/sirupsen/logrus" ) type Handler struct { @@ -59,7 +58,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { status, err = h.handleOptions(brw, r) case "GET", "HEAD", "POST": useBufferedWriter = false - status, err = h.handleGetHeadPost(w, r) + Writer := &common.WrittenResponseWriter{ResponseWriter: w} + status, err = h.handleGetHeadPost(Writer, r) + if status != 0 && Writer.IsWritten() { + status = 0 + } case "DELETE": status, err = h.handleDelete(brw, r) case "PUT": @@ -247,8 +250,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta } err = common.Proxy(w, r, link, fi) if err != nil { - log.Errorf("webdav proxy error: %+v", err) - return http.StatusInternalServerError, err + return http.StatusInternalServerError, fmt.Errorf("webdav proxy error: %+v", err) } } else if storage.GetStorage().WebdavProxy() && downProxyUrl != "" { u := fmt.Sprintf("%s%s?sign=%s",