caddyhttp: Address some Go 1.20 features (#6252)

Co-authored-by: Francis Lavoie <lavofr@gmail.com>
This commit is contained in:
Matt Holt
2024-04-23 20:05:57 -04:00
committed by GitHub
parent d404005339
commit 6d97d8d87b
8 changed files with 125 additions and 16 deletions

View File

@ -226,13 +226,22 @@ func StatusCodeMatches(actual, configured int) bool {
// in the implementation of http.Dir. The root is assumed to
// be a trusted path, but reqPath is not; and the output will
// never be outside of root. The resulting path can be used
// with the local file system.
// with the local file system. If root is empty, the current
// directory is assumed. If the cleaned request path is deemed
// not local according to lexical processing (i.e. ignoring links),
// it will be rejected as unsafe and only the root will be returned.
func SanitizedPathJoin(root, reqPath string) string {
if root == "" {
root = "."
}
path := filepath.Join(root, path.Clean("/"+reqPath))
relPath := path.Clean("/" + reqPath)[1:] // clean path and trim the leading /
if !filepath.IsLocal(relPath) {
// path is unsafe (see https://github.com/golang/go/issues/56336#issuecomment-1416214885)
return root
}
path := filepath.Join(root, filepath.FromSlash(relPath))
// filepath.Join also cleans the path, and cleaning strips
// the trailing slash, so we need to re-add it afterwards.

View File

@ -3,6 +3,7 @@ package caddyhttp
import (
"net/url"
"path/filepath"
"runtime"
"testing"
)
@ -12,9 +13,10 @@ func TestSanitizedPathJoin(t *testing.T) {
// %2f = /
// %5c = \
for i, tc := range []struct {
inputRoot string
inputPath string
expect string
inputRoot string
inputPath string
expect string
expectWindows string
}{
{
inputPath: "",
@ -63,7 +65,7 @@ func TestSanitizedPathJoin(t *testing.T) {
{
inputRoot: "/a/b",
inputPath: "/%2e%2e%2f%2e%2e%2f",
expect: filepath.Join("/", "a", "b") + separator,
expect: "/a/b", // inputPath fails the IsLocal test so only the root is returned
},
{
inputRoot: "/a/b",
@ -81,9 +83,16 @@ func TestSanitizedPathJoin(t *testing.T) {
expect: filepath.Join("C:\\www", "foo", "bar"),
},
{
inputRoot: "C:\\www",
inputPath: "/D:\\foo\\bar",
expect: filepath.Join("C:\\www", "D:\\foo\\bar"),
inputRoot: "C:\\www",
inputPath: "/D:\\foo\\bar",
expect: filepath.Join("C:\\www", "D:\\foo\\bar"),
expectWindows: filepath.Join("C:\\www"), // inputPath fails IsLocal on Windows
},
{
// https://github.com/golang/go/issues/56336#issuecomment-1416214885
inputRoot: "root",
inputPath: "/a/b/../../c",
expect: filepath.Join("root", "c"),
},
} {
// we don't *need* to use an actual parsed URL, but it
@ -96,6 +105,9 @@ func TestSanitizedPathJoin(t *testing.T) {
t.Fatalf("Test %d: invalid URL: %v", i, err)
}
actual := SanitizedPathJoin(tc.inputRoot, u.Path)
if runtime.GOOS == "windows" && tc.expectWindows != "" {
tc.expect = tc.expectWindows
}
if actual != tc.expect {
t.Errorf("Test %d: SanitizedPathJoin('%s', '%s') => '%s' (expected '%s')",
i, tc.inputRoot, tc.inputPath, actual, tc.expect)

View File

@ -15,6 +15,8 @@
package requestbody
import (
"time"
"github.com/dustin/go-humanize"
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
@ -44,8 +46,30 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
}
rb.MaxSize = int64(size)
case "read_timeout":
var timeoutStr string
if !h.AllArgs(&timeoutStr) {
return nil, h.ArgErr()
}
timeout, err := time.ParseDuration(timeoutStr)
if err != nil {
return nil, h.Errf("parsing read_timeout: %v", err)
}
rb.ReadTimeout = timeout
case "write_timeout":
var timeoutStr string
if !h.AllArgs(&timeoutStr) {
return nil, h.ArgErr()
}
timeout, err := time.ParseDuration(timeoutStr)
if err != nil {
return nil, h.Errf("parsing write_timeout: %v", err)
}
rb.WriteTimeout = timeout
default:
return nil, h.Errf("unrecognized servers option '%s'", h.Val())
return nil, h.Errf("unrecognized request_body subdirective '%s'", h.Val())
}
}

View File

@ -17,6 +17,9 @@ package requestbody
import (
"io"
"net/http"
"time"
"go.uber.org/zap"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
@ -31,6 +34,14 @@ type RequestBody struct {
// The maximum number of bytes to allow reading from the body by a later handler.
// If more bytes are read, an error with HTTP status 413 is returned.
MaxSize int64 `json:"max_size,omitempty"`
// EXPERIMENTAL. Subject to change/removal.
ReadTimeout time.Duration `json:"read_timeout,omitempty"`
// EXPERIMENTAL. Subject to change/removal.
WriteTimeout time.Duration `json:"write_timeout,omitempty"`
logger *zap.Logger
}
// CaddyModule returns the Caddy module information.
@ -41,6 +52,11 @@ func (RequestBody) CaddyModule() caddy.ModuleInfo {
}
}
func (rb *RequestBody) Provision(ctx caddy.Context) error {
rb.logger = ctx.Logger()
return nil
}
func (rb RequestBody) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
if r.Body == nil {
return next.ServeHTTP(w, r)
@ -48,6 +64,20 @@ func (rb RequestBody) ServeHTTP(w http.ResponseWriter, r *http.Request, next cad
if rb.MaxSize > 0 {
r.Body = errorWrapper{http.MaxBytesReader(w, r.Body, rb.MaxSize)}
}
if rb.ReadTimeout > 0 || rb.WriteTimeout > 0 {
//nolint:bodyclose
rc := http.NewResponseController(w)
if rb.ReadTimeout > 0 {
if err := rc.SetReadDeadline(time.Now().Add(rb.ReadTimeout)); err != nil {
rb.logger.Error("could not set read deadline", zap.Error(err))
}
}
if rb.WriteTimeout > 0 {
if err := rc.SetWriteDeadline(time.Now().Add(rb.WriteTimeout)); err != nil {
rb.logger.Error("could not set write deadline", zap.Error(err))
}
}
}
return next.ServeHTTP(w, r)
}

View File

@ -2,7 +2,6 @@ package caddyhttp
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
@ -75,20 +74,19 @@ func TestResponseWriterWrapperReadFrom(t *testing.T) {
// take precedence over our ReadFrom.
src := struct{ io.Reader }{strings.NewReader(srcData)}
fmt.Println(name)
if _, err := io.Copy(wrapped, src); err != nil {
t.Errorf("Copy() err = %v", err)
t.Errorf("%s: Copy() err = %v", name, err)
}
if got := tt.responseWriter.Written(); got != srcData {
t.Errorf("data = %q, want %q", got, srcData)
t.Errorf("%s: data = %q, want %q", name, got, srcData)
}
if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
if tt.wantReadFrom {
t.Errorf("ReadFrom() should have been called")
t.Errorf("%s: ReadFrom() should have been called", name)
} else {
t.Errorf("ReadFrom() should not have been called")
t.Errorf("%s: ReadFrom() should not have been called", name)
}
}
})