Minor enhancements/fixes to rewrite directive and template virt req's

This commit is contained in:
Matthew Holt
2019-10-16 15:18:02 -06:00
parent 2f91b44587
commit a458544d9f
3 changed files with 27 additions and 7 deletions

View File

@ -31,6 +31,9 @@ func init() {
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
var rewr Rewrite
for h.Next() {
if !h.NextArg() {
return nil, h.ArgErr()
}
rewr.URI = h.Val()
if h.NextArg() {
return nil, h.ArgErr()

View File

@ -22,6 +22,7 @@ import (
"net"
"net/http"
"path"
"strconv"
"strings"
"sync"
@ -79,8 +80,18 @@ func (c templateContext) Include(filename string, args ...interface{}) (template
// are NOT escaped, so you should only include trusted resources.
// If it is not trusted, be sure to use escaping functions yourself.
func (c templateContext) HTTPInclude(uri string) (template.HTML, error) {
if c.Req.Header.Get(recursionPreventionHeader) == "1" {
return "", fmt.Errorf("virtual request cycle")
// prevent virtual request loops by counting how many levels
// deep we are; and if we get too deep, return an error
recursionCount := 1
if numStr := c.Req.Header.Get(recursionPreventionHeader); numStr != "" {
num, err := strconv.Atoi(numStr)
if err != nil {
return "", fmt.Errorf("parsing %s: %v", recursionPreventionHeader, err)
}
if num >= 3 {
return "", fmt.Errorf("virtual request cycle")
}
recursionCount = num + 1
}
buf := bufPool.Get().(*bytes.Buffer)
@ -91,7 +102,10 @@ func (c templateContext) HTTPInclude(uri string) (template.HTML, error) {
if err != nil {
return "", err
}
virtReq.Header.Set(recursionPreventionHeader, "1")
virtReq.Host = c.Req.Host
virtReq.Header = c.Req.Header.Clone()
virtReq.Trailer = c.Req.Trailer.Clone()
virtReq.Header.Set(recursionPreventionHeader, strconv.Itoa(recursionCount))
vrw := &virtualResponseWriter{body: buf, header: make(http.Header)}
server := c.Req.Context().Value(caddyhttp.ServerCtxKey).(http.Handler)