diff --git a/agent/app/service/website_rewrite.go b/agent/app/service/website_rewrite.go index b2263c4900ac..af04f2e0798b 100644 --- a/agent/app/service/website_rewrite.go +++ b/agent/app/service/website_rewrite.go @@ -70,16 +70,16 @@ func (w WebsiteService) GetRewriteConfig(req request.NginxRewriteReq) (*response } } } else { - rewriteFile := fmt.Sprintf("rewrite/%s.conf", strings.ToLower(req.Name)) - contentByte, _ = nginx_conf.Rewrites.ReadFile(rewriteFile) - if contentByte == nil { - customRewriteDir := GetOpenrestyDir(DefaultRewriteDir) - safeName := path.Base(req.Name) - if safeName != req.Name || strings.Contains(safeName, "..") { - return nil, buserr.New("ErrInvalidParams") - } - customRewriteFile := path.Join(customRewriteDir, fmt.Sprintf("%s.conf", strings.ToLower(req.Name))) + safeName, err := getSafeRewriteName(req.Name) + if err != nil { + return nil, err + } + customRewriteFile := path.Join(GetOpenrestyDir(DefaultRewriteDir), fmt.Sprintf("%s.conf", safeName)) + if files.NewFileOp().Stat(customRewriteFile) { contentByte, err = files.NewFileOp().GetContent(customRewriteFile) + } else { + rewriteFile := fmt.Sprintf("rewrite/%s.conf", strings.ToLower(safeName)) + contentByte, _ = nginx_conf.Rewrites.ReadFile(rewriteFile) } } return &response.NginxRewriteRes{ @@ -95,14 +95,18 @@ func (w WebsiteService) OperateCustomRewrite(req request.CustomRewriteOperate) e return err } } - safeName := path.Base(req.Name) - if safeName != req.Name || strings.Contains(safeName, "..") { - return buserr.New("ErrInvalidParams") + safeName, err := getSafeRewriteName(req.Name) + if err != nil { + return err } - rewriteFile := path.Join(rewriteDir, fmt.Sprintf("%s.conf", req.Name)) + rewriteFile := path.Join(rewriteDir, fmt.Sprintf("%s.conf", safeName)) switch req.Operate { case "create": - if fileOp.Stat(rewriteFile) { + exist, err := customRewriteNameExist(rewriteDir, safeName) + if err != nil { + return err + } + if exist || builtinRewriteNameExist(safeName) { return buserr.New("ErrNameIsExist") } return fileOp.WriteFile(rewriteFile, strings.NewReader(req.Content), constant.DirPerm) @@ -112,6 +116,37 @@ func (w WebsiteService) OperateCustomRewrite(req request.CustomRewriteOperate) e return nil } +func getSafeRewriteName(name string) (string, error) { + safeName := path.Base(name) + if safeName != name || strings.Contains(safeName, "..") { + return "", buserr.New("ErrInvalidParams") + } + return safeName, nil +} + +func builtinRewriteNameExist(name string) bool { + rewriteFile := fmt.Sprintf("rewrite/%s.conf", strings.ToLower(name)) + contentByte, _ := nginx_conf.Rewrites.ReadFile(rewriteFile) + return contentByte != nil +} + +func customRewriteNameExist(rewriteDir, name string) (bool, error) { + entries, err := os.ReadDir(rewriteDir) + if err != nil { + return false, err + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + entryName := strings.TrimSuffix(entry.Name(), ".conf") + if strings.EqualFold(entryName, name) { + return true, nil + } + } + return false, nil +} + func (w WebsiteService) ListCustomRewrite() ([]string, error) { rewriteDir := GetOpenrestyDir(DefaultRewriteDir) fileOp := files.NewFileOp() diff --git a/agent/app/service/website_rewrite_test.go b/agent/app/service/website_rewrite_test.go new file mode 100644 index 000000000000..ffb3d577fcd9 --- /dev/null +++ b/agent/app/service/website_rewrite_test.go @@ -0,0 +1,59 @@ +package service + +import ( + "os" + "testing" +) + +func TestBuiltinRewriteNameExistIgnoresCase(t *testing.T) { + if !builtinRewriteNameExist("WordPress") { + t.Fatal("expected WordPress to match builtin wordpress rewrite") + } + if !builtinRewriteNameExist("EmpireCMS") { + t.Fatal("expected EmpireCMS to match builtin empirecms rewrite") + } + if builtinRewriteNameExist("not-exist-rewrite") { + t.Fatal("expected unknown rewrite name to be absent") + } +} + +func TestCustomRewriteNameExistIgnoresCase(t *testing.T) { + rewriteDir := t.TempDir() + if err := os.WriteFile(rewriteDir+"/MyRewrite.conf", []byte("content"), 0644); err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + want bool + }{ + {name: "MyRewrite", want: true}, + {name: "myrewrite", want: true}, + {name: "MYREWRITE", want: true}, + {name: "OtherRewrite", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := customRewriteNameExist(rewriteDir, tt.name) + if err != nil { + t.Fatal(err) + } + if got != tt.want { + t.Fatalf("customRewriteNameExist() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetSafeRewriteName(t *testing.T) { + if _, err := getSafeRewriteName("../rewrite"); err == nil { + t.Fatal("expected path traversal name to be rejected") + } + if _, err := getSafeRewriteName("rewrite/name"); err == nil { + t.Fatal("expected path separator name to be rejected") + } + if name, err := getSafeRewriteName("MyRewrite"); err != nil || name != "MyRewrite" { + t.Fatalf("getSafeRewriteName() = %q, %v; want MyRewrite, nil", name, err) + } +}