Skip to content

Commit

Permalink
code: global/up/down ["*"] allowed extemsions.
Browse files Browse the repository at this point in the history
  • Loading branch information
PlusOne committed Dec 30, 2024
1 parent 5e41c0c commit e5b5c61
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 22 deletions.
79 changes: 64 additions & 15 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,21 @@ func readConfig(configFilename string, conf *Config) error {
return fmt.Errorf("invalid bind_ip '%s'", conf.Server.BindIP)
}
}

// If uploads.allowedextensions is empty, inherit from server.globalextensions
if len(conf.Uploads.AllowedExtensions) == 0 {
conf.Uploads.AllowedExtensions = conf.Server.GlobalExtensions
} else if len(conf.Uploads.AllowedExtensions) == 1 && conf.Uploads.AllowedExtensions[0] == "*" {
conf.Uploads.AllowedExtensions = nil // nil signals all extensions allowed
}

// If downloads.allowedextensions is empty, inherit from server.globalextensions
if len(conf.Downloads.AllowedExtensions) == 0 {
conf.Downloads.AllowedExtensions = conf.Server.GlobalExtensions
} else if len(conf.Downloads.AllowedExtensions) == 1 && conf.Downloads.AllowedExtensions[0] == "*" {
conf.Downloads.AllowedExtensions = nil // nil signals all extensions allowed
}

return nil
}

Expand Down Expand Up @@ -1034,6 +1049,48 @@ func validateConfig(conf *Config) error {

// Additional configuration validations can be added here

validateAllowedExtensions := func(exts []string) error {
if exts == nil {
return nil // All extensions allowed
}
for _, ext := range exts {
if ext != "*" && !strings.HasPrefix(ext, ".") {
return fmt.Errorf("invalid extension '%s' (must start with '.' or be '*')", ext)
}
}
return nil
}

// Validate global, uploads, and downloads extensions
if err := validateAllowedExtensions(conf.Server.GlobalExtensions); err != nil {
return err
}
if err := validateAllowedExtensions(conf.Uploads.AllowedExtensions); err != nil {
return err
}
if err := validateAllowedExtensions(conf.Downloads.AllowedExtensions); err != nil {
return err
}

// Prevent '*' mixed with other extensions
hasWildcard := func(exts []string) bool {
for _, e := range exts {
if e == "*" {
return true
}
}
return false
}
if len(conf.Server.GlobalExtensions) > 1 && hasWildcard(conf.Server.GlobalExtensions) {
return fmt.Errorf("server.globalextensions cannot mix '*' with other entries")
}
if len(conf.Uploads.AllowedExtensions) > 1 && hasWildcard(conf.Uploads.AllowedExtensions) {
return fmt.Errorf("uploads.allowedextensions cannot mix '*' with other entries")
}
if len(conf.Downloads.AllowedExtensions) > 1 && hasWildcard(conf.Downloads.AllowedExtensions) {
return fmt.Errorf("downloads.allowedextensions cannot mix '*' with other entries")
}

return nil
}

Expand Down Expand Up @@ -1252,21 +1309,13 @@ func fileExists(filePath string) (bool, int64) {
return !fileInfo.IsDir(), fileInfo.Size()
}

func isExtensionAllowed(filename string) bool {
var allowedExtensions []string
if len(conf.Server.GlobalExtensions) > 0 {
allowedExtensions = conf.Server.GlobalExtensions
} else {
allowedExtensions = append(conf.Uploads.AllowedExtensions, conf.Downloads.AllowedExtensions...)
func isExtensionAllowed(filename string, allowedExts []string) bool {
if allowedExts == nil {
return true // all allowed
}

if len(allowedExtensions) == 0 {
return true
}

ext := strings.ToLower(filepath.Ext(filename))
for _, allowedExt := range allowedExtensions {
if strings.ToLower(allowedExt) == ext {
for _, allowed := range allowedExts {
if strings.EqualFold(ext, allowed) {
return true
}
}
Expand Down Expand Up @@ -1729,7 +1778,7 @@ func handleUpload(w http.ResponseWriter, r *http.Request, absFilename, fileStore
return
}

if !isExtensionAllowed(fileStorePath) {
if !isExtensionAllowed(fileStorePath, conf.Uploads.AllowedExtensions) {
log.Warn("Invalid file extension")
http.Error(w, "Invalid file extension", http.StatusBadRequest)
uploadErrorsTotal.Inc()
Expand Down Expand Up @@ -2364,7 +2413,7 @@ func handleMultipartUpload(w http.ResponseWriter, r *http.Request, absFilename s
}
defer file.Close()

if !isExtensionAllowed(handler.Filename) {
if !isExtensionAllowed(handler.Filename, conf.Uploads.AllowedExtensions) {
log.WithFields(logrus.Fields{"filename": handler.Filename, "extension": filepath.Ext(handler.Filename)}).Warn("Attempted upload with disallowed file extension")
http.Error(w, "Disallowed file extension. Allowed: "+strings.Join(conf.Uploads.AllowedExtensions, ", "), http.StatusForbidden)
uploadErrorsTotal.Inc()
Expand Down
9 changes: 2 additions & 7 deletions config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,13 @@ maxversions = 5
resumableuploadsenabled = true
chunkeduploadsenabled = true
chunksize = "8192"
allowedextensions = [".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".svg", ".webp", ".wav", ".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v", ".3gp", ".3g2", ".mp3", ".ogg"]
allowedextensions = ["*"] # Use ["*"] to allow all or specify extensions

[downloads]
resumabledownloadsenabled = true
chunkeddownloadsenabled = true
chunksize = "8192"
allowedextensions = [
".txt", ".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp",
".tiff", ".svg", ".webp", ".wav", ".mp4", ".avi", ".mkv",
".mov", ".wmv", ".flv", ".webm", ".mpeg", ".mpg", ".m4v",
".3gp", ".3g2", ".mp3", ".ogg"
]
allowedextensions = [".jpg", ".png"] # Restricts downloads to specific types

[clamav]
clamavenabled = true
Expand Down

0 comments on commit e5b5c61

Please sign in to comment.