diff --git a/internal/sign/archive.go b/internal/sign/archive.go new file mode 100644 index 00000000..26a2c208 --- /dev/null +++ b/internal/sign/archive.go @@ -0,0 +1,41 @@ +package sign + +import ( + "sync" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/sign" +) + +var onceArchive sync.Once +var instanceArchive sign.Sign + +func SignArchive(data string) string { + expire := setting.GetInt(conf.LinkExpiration, 0) + if expire == 0 { + return NotExpiredArchive(data) + } else { + return WithDurationArchive(data, time.Duration(expire)*time.Hour) + } +} + +func WithDurationArchive(data string, d time.Duration) string { + onceArchive.Do(InstanceArchive) + return instanceArchive.Sign(data, time.Now().Add(d).Unix()) +} + +func NotExpiredArchive(data string) string { + onceArchive.Do(InstanceArchive) + return instanceArchive.Sign(data, 0) +} + +func VerifyArchive(data string, sign string) error { + onceArchive.Do(InstanceArchive) + return instanceArchive.Verify(data, sign) +} + +func InstanceArchive() { + instanceArchive = sign.NewHMACSign([]byte(setting.GetStr(conf.Token) + "-archive")) +} diff --git a/server/debug.go b/server/debug.go index 081ef8c3..a4242abd 100644 --- a/server/debug.go +++ b/server/debug.go @@ -5,6 +5,7 @@ import ( _ "net/http/pprof" "runtime" + "github.com/alist-org/alist/v3/internal/sign" "github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/middlewares" "github.com/gin-gonic/gin" @@ -15,7 +16,7 @@ func _pprof(g *gin.RouterGroup) { } func debug(g *gin.RouterGroup) { - g.GET("/path/*path", middlewares.Down, func(ctx *gin.Context) { + g.GET("/path/*path", middlewares.Down(sign.Verify), func(ctx *gin.Context) { rawPath := ctx.MustGet("path").(string) ctx.JSON(200, gin.H{ "path": rawPath, diff --git a/server/handles/archive.go b/server/handles/archive.go index fab3916e..4ec933e1 100644 --- a/server/handles/archive.go +++ b/server/handles/archive.go @@ -120,7 +120,7 @@ func FsArchiveMeta(c *gin.Context) { } s := "" if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) { - s = sign.Sign(reqPath) + s = sign.SignArchive(reqPath) } api := "/ae" if ret.DriverProviding { diff --git a/server/middlewares/down.go b/server/middlewares/down.go index 05e9dc85..d015672d 100644 --- a/server/middlewares/down.go +++ b/server/middlewares/down.go @@ -9,35 +9,36 @@ import ( "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" - "github.com/alist-org/alist/v3/internal/sign" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/gin-gonic/gin" "github.com/pkg/errors" ) -func Down(c *gin.Context) { - rawPath := parsePath(c.Param("path")) - c.Set("path", rawPath) - meta, err := op.GetNearestMeta(rawPath) - if err != nil { - if !errors.Is(errors.Cause(err), errs.MetaNotFound) { - common.ErrorResp(c, err, 500, true) - return - } - } - c.Set("meta", meta) - // verify sign - if needSign(meta, rawPath) { - s := c.Query("sign") - err = sign.Verify(rawPath, strings.TrimSuffix(s, "/")) +func Down(verifyFunc func(string, string) error) func(c *gin.Context) { + return func(c *gin.Context) { + rawPath := parsePath(c.Param("path")) + c.Set("path", rawPath) + meta, err := op.GetNearestMeta(rawPath) if err != nil { - common.ErrorResp(c, err, 401) - c.Abort() - return + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } } + c.Set("meta", meta) + // verify sign + if needSign(meta, rawPath) { + s := c.Query("sign") + err = verifyFunc(rawPath, strings.TrimSuffix(s, "/")) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + } + c.Next() } - c.Next() } // TODO: implement diff --git a/server/router.go b/server/router.go index 830051d8..2dd6ee88 100644 --- a/server/router.go +++ b/server/router.go @@ -4,6 +4,7 @@ import ( "github.com/alist-org/alist/v3/cmd/flags" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/message" + "github.com/alist-org/alist/v3/internal/sign" "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" @@ -40,16 +41,18 @@ func Init(e *gin.Engine) { S3(g.Group("/s3")) downloadLimiter := middlewares.DownloadRateLimiter(stream.ClientDownloadLimit) - g.GET("/d/*path", middlewares.Down, downloadLimiter, handles.Down) - g.GET("/p/*path", middlewares.Down, downloadLimiter, handles.Proxy) - g.HEAD("/d/*path", middlewares.Down, handles.Down) - g.HEAD("/p/*path", middlewares.Down, handles.Proxy) - g.GET("/ad/*path", middlewares.Down, downloadLimiter, handles.ArchiveDown) - g.GET("/ap/*path", middlewares.Down, downloadLimiter, handles.ArchiveProxy) - g.GET("/ae/*path", middlewares.Down, downloadLimiter, handles.ArchiveInternalExtract) - g.HEAD("/ad/*path", middlewares.Down, handles.ArchiveDown) - g.HEAD("/ap/*path", middlewares.Down, handles.ArchiveProxy) - g.HEAD("/ae/*path", middlewares.Down, handles.ArchiveInternalExtract) + signCheck := middlewares.Down(sign.Verify) + g.GET("/d/*path", signCheck, downloadLimiter, handles.Down) + g.GET("/p/*path", signCheck, downloadLimiter, handles.Proxy) + g.HEAD("/d/*path", signCheck, handles.Down) + g.HEAD("/p/*path", signCheck, handles.Proxy) + archiveSignCheck := middlewares.Down(sign.VerifyArchive) + g.GET("/ad/*path", archiveSignCheck, downloadLimiter, handles.ArchiveDown) + g.GET("/ap/*path", archiveSignCheck, downloadLimiter, handles.ArchiveProxy) + g.GET("/ae/*path", archiveSignCheck, downloadLimiter, handles.ArchiveInternalExtract) + g.HEAD("/ad/*path", archiveSignCheck, handles.ArchiveDown) + g.HEAD("/ap/*path", archiveSignCheck, handles.ArchiveProxy) + g.HEAD("/ae/*path", archiveSignCheck, handles.ArchiveInternalExtract) api := g.Group("/api") auth := api.Group("", middlewares.Auth)