diff --git a/cmd/server/main.go b/cmd/server/main.go index 6195ed9..7115c9d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -4,11 +4,11 @@ package main import ( "bufio" - "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/hex" + "encoding/json" // Added for JSON handling "flag" "fmt" @@ -34,6 +34,7 @@ import ( "github.com/disintegration/imaging" "github.com/dutchcoders/go-clamd" // ClamAV integration + "github.com/fsnotify/fsnotify" // Added for directory monitoring "github.com/go-redis/redis/v8" // Redis integration "github.com/patrickmn/go-cache" "github.com/prometheus/client_golang/prometheus" @@ -137,6 +138,7 @@ type ServerConfig struct { DeduplicationEnabled bool `mapstructure:"deduplicationenabled"` Logging LoggingConfig `mapstructure:"logging"` GlobalExtensions []string `mapstructure:"globalextensions"` + BindIP string `mapstructure:"bind_ip"` } type DeduplicationConfig struct { @@ -206,6 +208,10 @@ type RedisConfig struct { type WorkersConfig struct { NumWorkers int `mapstructure:"numworkers"` UploadQueueSize int `mapstructure:"uploadqueuesize"` + MaxConcurrentOperations int `mapstructure:"max_concurrent_operations"` + NetworkEventBuffer int `mapstructure:"network_event_buffer"` + PerformanceMonitorInterval string `mapstructure:"performance_monitor_interval"` + MetricsUpdateInterval string `mapstructure:"metrics_update_interval"` } type FileConfig struct { @@ -216,6 +222,14 @@ type BuildConfig struct { Version string `mapstructure:"version"` } +// Step 1: Define PrecacheConfig +type PrecacheConfig struct { + RedisEnabled bool `mapstructure:"redisEnabled"` + RedisAddr string `mapstructure:"redisAddr"` + StaticIndexFile string `mapstructure:"staticIndexFile"` +} + +// Step 2: Update Config struct to include Precache type Config struct { Server ServerConfig `mapstructure:"server"` Logging LoggingConfig `mapstructure:"logging"` @@ -232,6 +246,7 @@ type Config struct { Workers WorkersConfig `mapstructure:"workers"` File FileConfig `mapstructure:"file"` Build BuildConfig `mapstructure:"build"` + Precache PrecacheConfig `mapstructure:"precache"` } type UploadTask struct { @@ -252,7 +267,9 @@ type NetworkEvent struct { // Add a new field to store the creation date of files type FileMetadata struct { - CreationDate time.Time `json:"creationDate"` + CreationDate time.Time + FilePath string + FileInfo os.FileInfo } var ( @@ -283,7 +300,7 @@ var ( downloadSizeBytes prometheus.Histogram scanQueue chan ScanTask - ScanWorkers = 5 + ScanWorkers = 0 thumbnailProcessedTotal prometheus.Counter thumbnailProcessingErrors prometheus.Counter @@ -303,9 +320,7 @@ var bufferPool = sync.Pool{ }, } -const maxConcurrentOperations = 10 - -var semaphore = make(chan struct{}, maxConcurrentOperations) +var semaphore chan struct{} var logMessages []string var logMu sync.Mutex @@ -423,7 +438,7 @@ func main() { // Set log level based on configuration level, err := logrus.ParseLevel(conf.Logging.Level) - if err != nil { + if (err != nil) { log.Warnf("Invalid log level '%s', defaulting to 'info'", conf.Logging.Level) level = logrus.InfoLevel } @@ -512,7 +527,8 @@ func main() { uploadQueue = make(chan UploadTask, conf.Workers.UploadQueueSize) scanQueue = make(chan ScanTask, conf.Workers.UploadQueueSize) - networkEvents = make(chan NetworkEvent, 100) + networkEvents = make(chan NetworkEvent, conf.Workers.NetworkEventBuffer) + semaphore = make(chan struct{}, conf.Workers.MaxConcurrentOperations) log.Info("Upload, scan, and network event channels initialized.") ctx, cancel := context.WithCancel(context.Background()) @@ -569,8 +585,13 @@ func main() { log.Fatalf("Invalid IdleTimeout: %v", err) } + listenAddress := conf.Server.BindIP + ":" + conf.Server.ListenPort + if net.ParseIP(conf.Server.BindIP) != nil && strings.Contains(conf.Server.BindIP, ":") { + // Enclose IPv6 in brackets + listenAddress = "[" + conf.Server.BindIP + "]:" + conf.Server.ListenPort + } server := &http.Server{ - Addr: ":" + conf.Server.ListenPort, + Addr: listenAddress, Handler: router, ReadTimeout: readTimeout, WriteTimeout: writeTimeout, @@ -602,7 +623,7 @@ func main() { versionString = conf.Build.Version log.Infof("Running version: %s", versionString) - log.Infof("Starting HMAC file server %s...", versionString) + log.Infof("Starting HMAC file server %s on %s...", versionString, listenAddress) if conf.Server.UnixSocket { if err := os.RemoveAll(conf.Server.ListenPort); err != nil { log.Fatalf("Failed to remove existing Unix socket: %v", err) @@ -664,7 +685,7 @@ func initializeWorkerSettings(server *ServerConfig, workers *WorkersConfig, clam } func monitorWorkerPerformance(ctx context.Context, server *ServerConfig, w *WorkersConfig, clamav *ClamAVConfig) { - ticker := time.NewTicker(5 * time.Minute) + ticker := time.NewTicker(parseDuration(conf.Workers.PerformanceMonitorInterval)) defer ticker.Stop() for { @@ -695,9 +716,33 @@ func readConfig(configFilename string, conf *Config) error { if err := viper.Unmarshal(conf); err != nil { return fmt.Errorf("unable to decode config into struct: %v", err) } + if conf.Server.BindIP == "" { + conf.Server.BindIP = "0.0.0.0" + log.Warn("bind_ip not set. Defaulting to 0.0.0.0") + } else { + if net.ParseIP(conf.Server.BindIP) == nil { + return fmt.Errorf("invalid bind_ip '%s'", conf.Server.BindIP) + } + } + + // If uploads.allowedextensions is empty, inherit from server.globalextensions + if len(conf.Uploads.AllowedExtensions) == 0 { + conf.Uploads.AllowedExtensions = conf.Server.GlobalExtensions + } else if len(conf.Uploads.AllowedExtensions) == 1 && conf.Uploads.AllowedExtensions[0] == "*" { + conf.Uploads.AllowedExtensions = nil // nil signals all extensions allowed + } + + // If downloads.allowedextensions is empty, inherit from server.globalextensions + if len(conf.Downloads.AllowedExtensions) == 0 { + conf.Downloads.AllowedExtensions = conf.Server.GlobalExtensions + } else if len(conf.Downloads.AllowedExtensions) == 1 && conf.Downloads.AllowedExtensions[0] == "*" { + conf.Downloads.AllowedExtensions = nil // nil signals all extensions allowed + } + return nil } +// Step 3: Set default values for precache in setDefaults func setDefaults() { viper.SetDefault("server.listenport", "8080") viper.SetDefault("server.unixsocket", false) @@ -713,6 +758,7 @@ func setDefaults() { viper.SetDefault("server.loggingjson", false) viper.SetDefault("server.filettlenabled", true) viper.SetDefault("server.deduplicationenabled", true) + viper.SetDefault("server.bind_ip", "0.0.0.0") viper.SetDefault("timeouts.readtimeout", "4800s") viper.SetDefault("timeouts.writetimeout", "4800s") @@ -744,6 +790,10 @@ func setDefaults() { viper.SetDefault("workers.numworkers", 4) viper.SetDefault("workers.uploadqueuesize", 50) + viper.SetDefault("workers.max_concurrent_operations", 10) + viper.SetDefault("workers.network_event_buffer", 100) + viper.SetDefault("workers.performance_monitor_interval", "5m") + viper.SetDefault("workers.metrics_update_interval", "10s") viper.SetDefault("deduplication.enabled", true) @@ -761,6 +811,11 @@ func setDefaults() { viper.SetDefault("logging.max_backups", 7) viper.SetDefault("logging.max_age", 30) viper.SetDefault("logging.compress", true) + + // Step 3: Set default values for precache in setDefaults + viper.SetDefault("precache.redisEnabled", true) + viper.SetDefault("precache.redisAddr", "localhost:6379") + viper.SetDefault("precache.staticIndexFile", "./static_index.json") } func validateConfig(conf *Config) error { @@ -818,7 +873,7 @@ func validateConfig(conf *Config) error { // Validate Uploads Configuration if conf.Uploads.ResumableUploadsEnabled { if conf.Uploads.ChunkSize == "" { - return fmt.Errorf("uploads.chunkSize must be set when resumable uploads are enabled") + return fmt.Errorf("uploads.chunksize must be set when resumable uploads are enabled") } if len(conf.Uploads.AllowedExtensions) == 0 { return fmt.Errorf("uploads.allowedextensions must have at least one extension") @@ -832,6 +887,18 @@ func validateConfig(conf *Config) error { if conf.Workers.UploadQueueSize <= 0 { return fmt.Errorf("workers.uploadQueueSize must be greater than 0") } + if conf.Workers.MaxConcurrentOperations <= 0 { + return fmt.Errorf("invalid max_concurrent_operations") + } + if conf.Workers.NetworkEventBuffer <= 0 { + return fmt.Errorf("invalid network_event_buffer") + } + if conf.Workers.PerformanceMonitorInterval == "" { + return fmt.Errorf("invalid performance_monitor_interval") + } + if conf.Workers.MetricsUpdateInterval == "" { + return fmt.Errorf("invalid metrics_update_interval") + } // Validate ClamAV Configuration if conf.ClamAV.ClamAVEnabled { @@ -982,6 +1049,48 @@ func validateConfig(conf *Config) error { // Additional configuration validations can be added here + validateAllowedExtensions := func(exts []string) error { + if exts == nil { + return nil // All extensions allowed + } + for _, ext := range exts { + if ext != "*" && !strings.HasPrefix(ext, ".") { + return fmt.Errorf("invalid extension '%s' (must start with '.' or be '*')", ext) + } + } + return nil + } + + // Validate global, uploads, and downloads extensions + if err := validateAllowedExtensions(conf.Server.GlobalExtensions); err != nil { + return err + } + if err := validateAllowedExtensions(conf.Uploads.AllowedExtensions); err != nil { + return err + } + if err := validateAllowedExtensions(conf.Downloads.AllowedExtensions); err != nil { + return err + } + + // Prevent '*' mixed with other extensions + hasWildcard := func(exts []string) bool { + for _, e := range exts { + if e == "*" { + return true + } + } + return false + } + if len(conf.Server.GlobalExtensions) > 1 && hasWildcard(conf.Server.GlobalExtensions) { + return fmt.Errorf("server.globalextensions cannot mix '*' with other entries") + } + if len(conf.Uploads.AllowedExtensions) > 1 && hasWildcard(conf.Uploads.AllowedExtensions) { + return fmt.Errorf("uploads.allowedextensions cannot mix '*' with other entries") + } + if len(conf.Downloads.AllowedExtensions) > 1 && hasWildcard(conf.Downloads.AllowedExtensions) { + return fmt.Errorf("downloads.allowedextensions cannot mix '*' with other entries") + } + return nil } @@ -1161,7 +1270,7 @@ func initMetrics() { } func updateSystemMetrics(ctx context.Context) { - ticker := time.NewTicker(10 * time.Second) + ticker := time.NewTicker(parseDuration(conf.Workers.MetricsUpdateInterval)) defer ticker.Stop() for { @@ -1200,21 +1309,13 @@ func fileExists(filePath string) (bool, int64) { return !fileInfo.IsDir(), fileInfo.Size() } -func isExtensionAllowed(filename string) bool { - var allowedExtensions []string - if len(conf.Server.GlobalExtensions) > 0 { - allowedExtensions = conf.Server.GlobalExtensions - } else { - allowedExtensions = append(conf.Uploads.AllowedExtensions, conf.Downloads.AllowedExtensions...) +func isExtensionAllowed(filename string, allowedExts []string) bool { + if allowedExts == nil { + return true // all allowed } - - if len(allowedExtensions) == 0 { - return true - } - ext := strings.ToLower(filepath.Ext(filename)) - for _, allowedExt := range allowedExtensions { - if strings.ToLower(allowedExt) == ext { + for _, allowed := range allowedExts { + if strings.EqualFold(ext, allowed) { return true } } @@ -1506,14 +1607,19 @@ func initializeScanWorkerPool(ctx context.Context) { func setupRouter() http.Handler { mux := http.NewServeMux() + + // Thumbnails endpoint + mux.HandleFunc("/thumbnails", handleThumbnails) + + // Existing handlers mux.HandleFunc("/", handleRequest) - if (conf.Server.MetricsEnabled) { + + if conf.Server.MetricsEnabled { mux.Handle("/metrics", promhttp.Handler()) } handler := loggingMiddleware(mux) handler = recoveryMiddleware(handler) handler = corsMiddleware(handler) - handler = HMACAuthenticationMiddleware(handler) // Add HMAC middleware here return handler } @@ -1549,37 +1655,6 @@ func corsMiddleware(next http.Handler) http.Handler { }) } -// HMACAuthenticationMiddleware verifies the HMAC signature of incoming requests -func HMACAuthenticationMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - signature := r.Header.Get("X-HMAC-Signature") - if signature == "" { - http.Error(w, "Missing HMAC signature", http.StatusUnauthorized) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Unable to read request body", http.StatusBadRequest) - return - } - // Restore the Body so the next handler can read it - r.Body = io.NopCloser(bytes.NewBuffer(body)) - - mac := hmac.New(sha256.New, []byte(conf.Security.Secret)) - mac.Write(body) - expectedMAC := mac.Sum(nil) - expectedSignature := hex.EncodeToString(expectedMAC) - - if !hmac.Equal([]byte(expectedSignature), []byte(signature)) { - http.Error(w, "Invalid HMAC signature", http.StatusUnauthorized) - return - } - - next.ServeHTTP(w, r) - }) -} - func handleRequest(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost && strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") { absFilename, err := sanitizeFilePath(conf.Server.StoragePath, strings.TrimPrefix(r.URL.Path, "/")) @@ -1615,7 +1690,8 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { log.WithFields(logrus.Fields{"method": r.Method, "url": r.URL.String(), "remote": clientIP}).Info("Incoming request") p := r.URL.Path - if _, err := url.ParseQuery(r.URL.RawQuery); err != nil { + a, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { log.Warn("Failed to parse query parameters") http.Error(w, "Internal Server Error", http.StatusInternalServerError) return @@ -1643,7 +1719,7 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodPut: - handleUpload(w, absFilename) + handleUpload(w, r, absFilename, fileStorePath, a) case http.MethodHead, http.MethodGet: handleDownload(w, r, absFilename, fileStorePath) case http.MethodOptions: @@ -1657,26 +1733,151 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { } // handleUpload handles PUT requests for file uploads -func handleUpload(w http.ResponseWriter, absFilename string) { - log.Infof("Starting handleUpload for file: %s", absFilename) +func handleUpload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string, a url.Values) { + log.Infof("Using storage path: %s", conf.Server.StoragePath) + + // HMAC validation + var protocolVersion string + if a.Get("v2") != "" { + protocolVersion = "v2" + } else if a.Get("token") != "" { + protocolVersion = "token" + } else if a.Get("v") != "" { + protocolVersion = "v" + } else { + log.Warn("No HMAC attached to URL.") + http.Error(w, "No HMAC attached to URL. Expecting 'v', 'v2', or 'token' parameter as MAC", http.StatusForbidden) + return + } + + mac := hmac.New(sha256.New, []byte(conf.Security.Secret)) + + if protocolVersion == "v" { + mac.Write([]byte(fileStorePath + "\x20" + strconv.FormatInt(r.ContentLength, 10))) + } else { + contentType := mime.TypeByExtension(filepath.Ext(fileStorePath)) + if contentType == "" { + contentType = "application/octet-stream" + } + mac.Write([]byte(fileStorePath + "\x00" + strconv.FormatInt(r.ContentLength, 10) + "\x00" + contentType)) + } + + calculatedMAC := mac.Sum(nil) + + providedMACHex := a.Get(protocolVersion) + providedMAC, err := hex.DecodeString(providedMACHex) + if err != nil { + log.Warn("Invalid MAC encoding") + http.Error(w, "Invalid MAC encoding", http.StatusForbidden) + return + } + + if !hmac.Equal(calculatedMAC, providedMAC) { + log.Warn("Invalid MAC") + http.Error(w, "Invalid MAC", http.StatusForbidden) + return + } + + if !isExtensionAllowed(fileStorePath, conf.Uploads.AllowedExtensions) { + log.Warn("Invalid file extension") + http.Error(w, "Invalid file extension", http.StatusBadRequest) + uploadErrorsTotal.Inc() + return + } + + minFreeBytes, err := parseSize(conf.Server.MinFreeBytes) + if err != nil { + log.Fatalf("Invalid MinFreeBytes: %v", err) + } + err = checkStorageSpace(conf.Server.StoragePath, minFreeBytes) + if err != nil { + log.Warn("Not enough free space") + http.Error(w, "Not enough free space", http.StatusInsufficientStorage) + uploadErrorsTotal.Inc() + return + } - // ...existing code... + // Create temp file and write the uploaded data + tempFilename := absFilename + ".tmp" + err = createFile(tempFilename, r) + if err != nil { + log.WithFields(logrus.Fields{ + "filename": absFilename, + }).WithError(err).Error("Error creating temp file") + http.Error(w, "Error writing temp file", http.StatusInternalServerError) + return + } - tempFilename := absFilename + ".tmp" // Declare tempFilename - err := os.Rename(tempFilename, absFilename) // Declare and assign err + // Move temp file to final destination + err = os.Rename(tempFilename, absFilename) if err != nil { - // ...existing error handling... - w.WriteHeader(http.StatusInternalServerError) - log.Errorf("Error renaming file: %v", err) - return // Add return to prevent unreachable code + log.Errorf("Rename failed for %s: %v", absFilename, err) + os.Remove(tempFilename) + http.Error(w, "Error moving file to final destination", http.StatusInternalServerError) + return } - // Respond with 201 Created once + // Respond with 201 Created immediately w.WriteHeader(http.StatusCreated) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } log.Infof("Responded with 201 Created for file: %s", absFilename) // Asynchronous processing in the background - // ...existing code... + go func() { + var logMessages []string + + // ClamAV scanning + if conf.ClamAV.ClamAVEnabled && shouldScanFile(absFilename) { + err := scanFileWithClamAV(absFilename) + if err != nil { + logMessages = append(logMessages, fmt.Sprintf("ClamAV failed for %s: %v", absFilename, err)) + for _, msg := range logMessages { + log.Info(msg) + } + return + } else { + logMessages = append(logMessages, fmt.Sprintf("ClamAV scan passed for file: %s", absFilename)) + } + } + + // Deduplication + if conf.Redis.RedisEnabled && conf.Server.DeduplicationEnabled { + err := handleDeduplication(context.Background(), absFilename) + if err != nil { + log.Errorf("Deduplication failed for %s: %v", absFilename, err) + os.Remove(absFilename) + uploadErrorsTotal.Inc() + return + } else { + logMessages = append(logMessages, fmt.Sprintf("Deduplication handled successfully for file: %s", absFilename)) + } + } + + // Versioning + if conf.Versioning.EnableVersioning { + if exists, _ := fileExists(absFilename); exists { + err := versionFile(absFilename) + if err != nil { + log.Errorf("Versioning failed for %s: %v", absFilename, err) + os.Remove(absFilename) + uploadErrorsTotal.Inc() + return + } else { + logMessages = append(logMessages, fmt.Sprintf("File versioned successfully: %s", absFilename)) + } + } + } + + logMessages = append(logMessages, fmt.Sprintf("Processing completed successfully for %s", absFilename)) + uploadsTotal.Inc() + + // Log all messages at once + for _, msg := range logMessages { + log.Info(msg) + } + }() } func handleDownload(w http.ResponseWriter, r *http.Request, absFilename, fileStorePath string) { @@ -2212,7 +2413,7 @@ func handleMultipartUpload(w http.ResponseWriter, r *http.Request, absFilename s } defer file.Close() - if !isExtensionAllowed(handler.Filename) { + if !isExtensionAllowed(handler.Filename, conf.Uploads.AllowedExtensions) { log.WithFields(logrus.Fields{"filename": handler.Filename, "extension": filepath.Ext(handler.Filename)}).Warn("Attempted upload with disallowed file extension") http.Error(w, "Disallowed file extension. Allowed: "+strings.Join(conf.Uploads.AllowedExtensions, ", "), http.StatusForbidden) uploadErrorsTotal.Inc() @@ -2477,25 +2678,156 @@ func handleISOContainer(absFilename string) error { return nil } +// Step 4: Implement precacheStoragePath function func precacheStoragePath(dir string) error { - return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + // Attempt to load directory index from Redis + if conf.Precache.RedisEnabled && redisConnected { + data, err := redisClient.Get(context.Background(), "directory_index").Result() + if err == nil { + var index []FileMetadata + err = json.Unmarshal([]byte(data), &index) // Added the target variable + if err == nil { + for _, metadata := range index { + fileInfoCache.Set(metadata.FilePath, metadata.FileInfo, cache.DefaultExpiration) + fileMetadataCache.Set(metadata.FilePath, metadata, cache.DefaultExpiration) + } + log.Info("Loaded directory index from Redis") + return nil + } + log.Warn("Failed to unmarshal directory index from Redis") + } else { + log.Warn("Failed to load directory index from Redis:", err) + } + } + + // Attempt to load directory index from static index file + staticIndexFile := conf.Precache.StaticIndexFile + file, err := os.Open(staticIndexFile) + if err == nil { + defer file.Close() + var index []FileMetadata + decoder := json.NewDecoder(file) + err = decoder.Decode(&index) + if err == nil { + for _, metadata := range index { + fileInfoCache.Set(metadata.FilePath, metadata.FileInfo, cache.DefaultExpiration) + fileMetadataCache.Set(metadata.FilePath, metadata, cache.DefaultExpiration) + } + log.Info("Loaded directory index from static index file") + return nil + } + log.Warn("Failed to decode static index file:", err) + } else { + log.Warn("Static index file not found:", err) + } + + // Perform full directory scan + var index []FileMetadata + err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { if err != nil { - log.Warnf("Error accessing path %s: %v", path, err) - return nil // Continue walking + return err } if !info.IsDir() { + metadata := FileMetadata{ + FilePath: path, + FileInfo: info, + } + index = append(index, metadata) fileInfoCache.Set(path, info, cache.DefaultExpiration) - fileMetadataCache.Set(path, FileMetadata{CreationDate: info.ModTime()}, cache.DefaultExpiration) - log.Debugf("Cached file info and metadata for %s", path) + fileMetadataCache.Set(path, metadata, cache.DefaultExpiration) } return nil }) + if err != nil { + return err + } + log.Info("Performed full directory scan and populated caches") + + // Save index to Redis + if conf.Precache.RedisEnabled && redisConnected { + data, err := json.Marshal(index) + if err == nil { + err = redisClient.Set(context.Background(), "directory_index", data, 0).Err() + if err == nil { + log.Info("Saved directory index to Redis") + } else { + log.Warn("Failed to save directory index to Redis:", err) + } + } else { + log.Warn("Failed to marshal directory index for Redis:", err) + } + } + + // Save index to static index file + file, err = os.Create(staticIndexFile) + if err == nil { + defer file.Close() + encoder := json.NewEncoder(file) + err = encoder.Encode(index) + if err == nil { + log.Info("Saved directory index to static index file") + } else { + log.Warn("Failed to encode directory index to static index file:", err) + } + } else { + log.Warn("Failed to create static index file:", err) + } + + // Monitor directory for changes + go monitorDirectoryChanges(dir) + + return nil +} + +// Step 5: Implement monitorDirectoryChanges function +func monitorDirectoryChanges(dir string) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Error("Failed to create directory watcher:", err) + return + } + defer watcher.Close() + + err = watcher.Add(dir) + if err != nil { + log.Error("Failed to add directory to watcher:", err) + return + } + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Create == fsnotify.Create || event.Op&fsnotify.Remove == fsnotify.Remove { + log.Infof("Directory change detected: %s", event.String()) + // Update caches and indices + err := precacheStoragePath(dir) + if err != nil { + log.Error("Failed to update precache after directory change:", err) + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Error("Directory watcher error:", err) + } + } } -// generateThumbnail exclusively uses imaging and logs additional details func generateThumbnail(originalPath, thumbnailDir, size string) error { - log.Infof("Starting thumbnail generation for: %s; target directory: %s; requested size: %s", - originalPath, thumbnailDir, size) + // Check if thumbnail generation is enabled + if (!conf.Thumbnails.Enabled) { + return nil + } + + // Check if the file is an image + if !isImageFile(originalPath) { + log.Infof("File %s is not an image. Skipping thumbnail generation.", originalPath) + return nil + } // Parse the size (e.g., "200x200") dimensions := strings.Split(size, "x") @@ -2511,22 +2843,53 @@ func generateThumbnail(originalPath, thumbnailDir, size string) error { return fmt.Errorf("invalid height: %v", err) } + // Define the thumbnail file path thumbnailPath := filepath.Join(thumbnailDir, filepath.Base(originalPath)) - log.Infof("Thumbnail will be saved at: %s", thumbnailPath) - err = generateThumbnailWithImaging(originalPath, thumbnailPath, width, height) + // Check if thumbnail already exists + if _, err := os.Stat(thumbnailPath); err == nil { + log.Infof("Thumbnail already exists for %s. Skipping generation.", originalPath) + return nil + } else if !os.IsNotExist(err) { + return fmt.Errorf("error checking thumbnail existence: %v", err) + } + + // Check if ffmpeg is installed + if isFFmpegInstalled() { + // Use ffmpeg to generate the thumbnail + err := generateThumbnailWithFFmpeg(originalPath, thumbnailPath, width, height) + if err != nil { + return fmt.Errorf("failed to generate thumbnail with ffmpeg: %v", err) + } + } else { + // Use Go internal imaging function to generate the thumbnail + err := generateThumbnailWithImaging(originalPath, thumbnailPath, width, height) + if err != nil { + return fmt.Errorf("failed to generate thumbnail with imaging: %v", err) + } + } + + log.Infof("Generated thumbnail for %s at %s", originalPath, thumbnailPath) + thumbnailProcessedTotal.Inc() + return nil +} + +func isFFmpegInstalled() bool { + _, err := exec.LookPath("ffmpeg") + return err == nil +} + +func generateThumbnailWithFFmpeg(originalPath, thumbnailPath string, width, height int) error { + cmd := exec.Command("ffmpeg", "-i", originalPath, "-vf", fmt.Sprintf("thumbnail,scale=%d:%d", width, height), "-frames:v", "1", thumbnailPath) + output, err := cmd.CombinedOutput() if err != nil { - log.Errorf("Thumbnail generation failed for %s -> %s: %v", originalPath, thumbnailPath, err) + log.Errorf("ffmpeg output: %s", string(output)) return err } - log.Infof("Thumbnail generation completed successfully for: %s -> %s", originalPath, thumbnailPath) return nil } -// enhance logging in generateThumbnailWithImaging func generateThumbnailWithImaging(originalPath, thumbnailPath string, width, height int) error { - log.Infof("Using imaging to generate thumbnail: original=%s, thumbnail=%s, size=%dx%d", - originalPath, thumbnailPath, width, height) // Open the original image img, err := imaging.Open(originalPath) if err != nil { @@ -2541,7 +2904,6 @@ func generateThumbnailWithImaging(originalPath, thumbnailPath string, width, hei if err != nil { return fmt.Errorf("failed to save thumbnail: %v", err) } - log.Infof("Thumbnail saved successfully at: %s", thumbnailPath) return nil } @@ -2654,4 +3016,67 @@ func isImageFile(path string) bool { default: return false } +} + +// Add or replace the function to authenticate requests +func authenticateRequest(r *http.Request) bool { + // Placeholder logic; replace with your own authentication method + apiKey := r.Header.Get("X-API-Key") + expectedAPIKey := "your-secure-api-key" + return hmac.Equal([]byte(apiKey), []byte(expectedAPIKey)) +} + +// Add or replace the handler for /thumbnails +func handleThumbnails(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + userID := r.URL.Query().Get("user_id") + if userID == "" { + http.Error(w, "Missing user_id parameter", http.StatusBadRequest) + return + } + + // Authenticate the request + if !authenticateRequest(r) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // Construct the thumbnail file path (assuming JPEG) + thumbnailPath := filepath.Join(conf.Thumbnails.Directory, fmt.Sprintf("%s.jpg", userID)) + + fileInfo, err := os.Stat(thumbnailPath) + if os.IsNotExist(err) { + http.Error(w, "Thumbnail not found", http.StatusNotFound) + return + } else if err != nil { + log.WithError(err).Errorf("Error accessing thumbnail for user_id: %s", userID) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + file, err := os.Open(thumbnailPath) + if err != nil { + log.WithError(err).Errorf("Failed to open thumbnail for user_id: %s", userID) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + defer file.Close() + + // Determine the Content-Type based on file extension + ext := strings.ToLower(filepath.Ext(thumbnailPath)) + contentType := mime.TypeByExtension(ext) + if contentType == "" { + contentType = "application/octet-stream" + } + + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Length", strconv.FormatInt(fileInfo.Size(), 10)) + + if _, err := io.Copy(w, file); err != nil { + log.WithError(err).Errorf("Failed to serve thumbnail for user_id: %s", userID) + } } \ No newline at end of file diff --git a/go.mod b/go.mod index a29b631..a075899 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fsnotify/fsnotify v1.7.0 github.com/gdamore/encoding v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect