diff --git a/attachment/backend.go b/attachment/backend.go
new file mode 100644
index 00000000..8989b890
--- /dev/null
+++ b/attachment/backend.go
@@ -0,0 +1,22 @@
+package attachment
+
+import (
+	"io"
+	"time"
+)
+
+// object represents an object stored in a backend.
+type object struct {
+	ID           string
+	Size         int64
+	LastModified time.Time
+}
+
+// backend is a minimal I/O interface for storing and retrieving attachment files.
+// It has no knowledge of size tracking, limiting, or ID validation.
+type backend interface {
+	Put(id string, in io.Reader) error
+	Get(id string) (io.ReadCloser, int64, error)
+	Delete(ids ...string) error
+	List() ([]object, error)
+}
diff --git a/attachment/backend_file.go b/attachment/backend_file.go
new file mode 100644
index 00000000..b0afb2ca
--- /dev/null
+++ b/attachment/backend_file.go
@@ -0,0 +1,85 @@
+package attachment
+
+import (
+	"io"
+	"os"
+	"path/filepath"
+
+	"heckel.io/ntfy/v2/log"
+)
+
+const tagFileBackend = "file_backend"
+
+type fileBackend struct {
+	dir string
+}
+
+var _ backend = (*fileBackend)(nil)
+
+func newFileBackend(dir string) (*fileBackend, error) {
+	if err := os.MkdirAll(dir, 0700); err != nil {
+		return nil, err
+	}
+	return &fileBackend{dir: dir}, nil
+}
+
+func (b *fileBackend) Put(id string, in io.Reader) error {
+	file := filepath.Join(b.dir, id)
+	f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	if _, err := io.Copy(f, in); err != nil {
+		os.Remove(file)
+		return err
+	}
+	if err := f.Close(); err != nil {
+		os.Remove(file)
+		return err
+	}
+	return nil
+}
+
+func (b *fileBackend) Get(id string) (io.ReadCloser, int64, error) {
+	file := filepath.Join(b.dir, id)
+	stat, err := os.Stat(file)
+	if err != nil {
+		return nil, 0, err
+	}
+	f, err := os.Open(file)
+	if err != nil {
+		return nil, 0, err
+	}
+	return f, stat.Size(), nil
+}
+
+func (b *fileBackend) Delete(ids ...string) error {
+	for _, id := range ids {
+		file := filepath.Join(b.dir, id)
+		if err := os.Remove(file); err != nil {
+			log.Tag(tagFileBackend).Field("message_id", id).Err(err).Debug("Error deleting attachment")
+		}
+	}
+	return nil
+}
+
+func (b *fileBackend) List() ([]object, error) {
+	entries, err := os.ReadDir(b.dir)
+	if err != nil {
+		return nil, err
+	}
+	objects := make([]object, 0, len(entries))
+	for _, e := range entries {
+		info, err := e.Info()
+		if err != nil {
+			return nil, err
+		}
+		objects = append(objects, object{
+			ID:           e.Name(),
+			Size:         info.Size(),
+			LastModified: info.ModTime(),
+		})
+	}
+	return objects, nil
+}
diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go
new file mode 100644
index 00000000..8fcd8ccb
--- /dev/null
+++ b/attachment/backend_s3.go
@@ -0,0 +1,70 @@
+package attachment
+
+import (
+	"context"
+	"io"
+	"strings"
+
+	"heckel.io/ntfy/v2/s3"
+
+	"heckel.io/ntfy/v2/log"
+)
+
+const tagS3Backend = "s3_backend"
+
+type s3Backend struct {
+	client *s3.Client
+}
+
+var _ backend = (*s3Backend)(nil)
+
+func newS3Backend(client *s3.Client) *s3Backend {
+	return &s3Backend{client: client}
+}
+
+func (b *s3Backend) Put(id string, in io.Reader) error {
+	return b.client.PutObject(context.Background(), id, in)
+}
+
+func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) {
+	return b.client.GetObject(context.Background(), id)
+}
+
+func (b *s3Backend) Delete(ids ...string) error {
+	// S3 DeleteObjects supports up to 1000 keys per call
+	for i := 0; i < len(ids); i += 1000 {
+		end := i + 1000
+		if end > len(ids) {
+			end = 
len(ids) + } + batch := ids[i:end] + for _, id := range batch { + log.Tag(tagS3Backend).Field("message_id", id).Debug("Deleting attachment from S3") + } + if err := b.client.DeleteObjects(context.Background(), batch); err != nil { + return err + } + } + return nil +} + +func (b *s3Backend) List() ([]object, error) { + objects, err := b.client.ListAllObjects(context.Background()) + if err != nil { + return nil, err + } + prefix := b.client.Prefix + result := make([]object, 0, len(objects)) + for _, obj := range objects { + id := obj.Key + if prefix != "" { + id = strings.TrimPrefix(id, prefix+"/") + } + result = append(result, object{ + ID: id, + Size: obj.Size, + LastModified: obj.LastModified, + }) + } + return result, nil +} diff --git a/attachment/store.go b/attachment/store.go index 302eb585..f66cd35c 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -5,21 +5,193 @@ import ( "fmt" "io" "regexp" + "sync" + "time" + "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/model" + "heckel.io/ntfy/v2/s3" "heckel.io/ntfy/v2/util" ) +const ( + tagStore = "attachment_cache" + syncInterval = 15 * time.Minute // How often to run the background sync loop + orphanGracePeriod = time.Hour // Don't delete orphaned objects younger than this to avoid races with in-flight uploads +) + var ( fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength)) errInvalidFileID = errors.New("invalid file ID") ) -// Store is an interface for storing and retrieving attachment files -type Store interface { - Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) - Read(id string) (io.ReadCloser, int64, error) - Remove(ids ...string) error - Size() int64 - Remaining() int64 +// Store manages attachment storage with shared logic for size tracking, limiting, +// ID validation, and background sync to reconcile storage with the database. +type Store struct { + backend backend + totalSizeCurrent int64 + totalSizeLimit int64 + localIDs func() ([]string, error) // returns IDs that should exist + closeChan chan struct{} + mu sync.Mutex // Protects totalSizeCurrent +} + +// NewFileStore creates a new file-system backed attachment cache +func NewFileStore(dir string, totalSizeLimit int64, localIDsFn func() ([]string, error)) (*Store, error) { + backend, err := newFileBackend(dir) + if err != nil { + return nil, err + } + return newStore(backend, totalSizeLimit, localIDsFn) +} + +// NewS3Store creates a new S3-backed attachment cache. The s3URL must be in the format: +// +// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] +func NewS3Store(s3URL string, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) { + config, err := s3.ParseURL(s3URL) + if err != nil { + return nil, err + } + return newStore(newS3Backend(s3.New(config)), totalSizeLimit, localIDs) +} + +func newStore(backend backend, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) { + c := &Store{ + backend: backend, + totalSizeLimit: totalSizeLimit, + localIDs: localIDs, + closeChan: make(chan struct{}), + } + if localIDs != nil { + go c.syncLoop() + } + return c, nil +} + +// Write stores an attachment file. The id is validated, and the write is subject to +// the total size limit and any additional limiters. 
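+//
+// A minimal caller sketch (hypothetical names; matching with errors.Is is assumed,
+// since the S3 backend may wrap the limiter error):
+//
+//	size, err := store.Write(m.ID, body, visitorLimiter)
+//	if errors.Is(err, util.ErrLimitReached) {
+//		// Reject the upload; Write has already cleaned up the partial object.
+//	}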
+func (c *Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
+	if !fileIDRegex.MatchString(id) {
+		return 0, errInvalidFileID
+	}
+	log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment")
+	limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
+	cr := util.NewCountingReader(in)
+	lr := util.NewLimitReader(cr, limiters...)
+	if err := c.backend.Put(id, lr); err != nil {
+		c.backend.Delete(id) //nolint:errcheck
+		return 0, err
+	}
+	size := cr.Total()
+	c.mu.Lock()
+	c.totalSizeCurrent += size
+	c.mu.Unlock()
+	return size, nil
+}
+
+// Read retrieves an attachment file by ID
+func (c *Store) Read(id string) (io.ReadCloser, int64, error) {
+	if !fileIDRegex.MatchString(id) {
+		return nil, 0, errInvalidFileID
+	}
+	return c.backend.Get(id)
+}
+
+// Remove deletes attachment files by ID. It does NOT recompute the total size;
+// the next sync() call will correct it.
+func (c *Store) Remove(ids ...string) error {
+	for _, id := range ids {
+		if !fileIDRegex.MatchString(id) {
+			return errInvalidFileID
+		}
+	}
+	return c.backend.Delete(ids...)
+}
+
+// sync reconciles the backend storage with the database. It lists all objects,
+// deletes orphans (not in the valid ID set and older than orphanGracePeriod), and
+// recomputes the total size from the remaining objects.
+func (c *Store) sync() error {
+	log.Tag(tagStore).Debug("Sync: starting sync pass")
+	localIDs, err := c.localIDs()
+	if err != nil {
+		return fmt.Errorf("attachment sync: failed to get valid IDs: %w", err)
+	}
+	localIDMap := make(map[string]struct{}, len(localIDs))
+	for _, id := range localIDs {
+		localIDMap[id] = struct{}{}
+	}
+	remoteObjects, err := c.backend.List()
+	if err != nil {
+		return fmt.Errorf("attachment sync: failed to list objects: %w", err)
+	}
+	// Calculate total cache size and collect orphaned attachments, excluding objects younger
+	// than the grace period to account for races, and skipping objects with invalid IDs.
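+	// Young orphans still count toward the total size: they may be in-flight uploads
+	// whose database rows have not been committed yet.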
+ cutoff := time.Now().Add(-orphanGracePeriod) + var orphanIDs []string + var totalSize int64 + for _, obj := range remoteObjects { + if !fileIDRegex.MatchString(obj.ID) { + continue + } + if _, ok := localIDMap[obj.ID]; !ok && obj.LastModified.Before(cutoff) { + orphanIDs = append(orphanIDs, obj.ID) + } else { + totalSize += obj.Size + } + } + log.Tag(tagStore).Debug("Sync: cache size updated to %s", util.FormatSizeHuman(totalSize)) + c.mu.Lock() + c.totalSizeCurrent = totalSize + c.mu.Unlock() + // Delete orphaned attachments + if len(orphanIDs) > 0 { + log.Tag(tagStore).Debug("Sync: deleting %d orphaned attachment(s)", len(orphanIDs)) + if err := c.backend.Delete(orphanIDs...); err != nil { + return fmt.Errorf("attachment sync: failed to delete orphaned objects: %w", err) + } + } + return nil +} + +// Size returns the current total size of all attachments +func (c *Store) Size() int64 { + c.mu.Lock() + defer c.mu.Unlock() + return c.totalSizeCurrent +} + +// Remaining returns the remaining capacity for attachments +func (c *Store) Remaining() int64 { + c.mu.Lock() + defer c.mu.Unlock() + remaining := c.totalSizeLimit - c.totalSizeCurrent + if remaining < 0 { + return 0 + } + return remaining +} + +// Close stops the background sync goroutine +func (c *Store) Close() { + close(c.closeChan) +} + +func (c *Store) syncLoop() { + if err := c.sync(); err != nil { + log.Tag(tagStore).Err(err).Warn("Attachment sync failed") + } + ticker := time.NewTicker(syncInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := c.sync(); err != nil { + log.Tag(tagStore).Err(err).Warn("Attachment sync failed") + } + case <-c.closeChan: + return + } + } } diff --git a/attachment/store_file.go b/attachment/store_file.go deleted file mode 100644 index b26e86f0..00000000 --- a/attachment/store_file.go +++ /dev/null @@ -1,139 +0,0 @@ -package attachment - -import ( - "errors" - "io" - "os" - "path/filepath" - "sync" - - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/util" -) - -const tagFileStore = "file_store" - -var errFileExists = errors.New("file exists") - -type fileStore struct { - dir string - totalSizeCurrent int64 - totalSizeLimit int64 - mu sync.Mutex -} - -// NewFileStore creates a new file-system backed attachment store -func NewFileStore(dir string, totalSizeLimit int64) (Store, error) { - if err := os.MkdirAll(dir, 0700); err != nil { - return nil, err - } - size, err := dirSize(dir) - if err != nil { - return nil, err - } - return &fileStore{ - dir: dir, - totalSizeCurrent: size, - totalSizeLimit: totalSizeLimit, - }, nil -} - -func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { - if !fileIDRegex.MatchString(id) { - return 0, errInvalidFileID - } - log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment") - file := filepath.Join(c.dir, id) - if _, err := os.Stat(file); err == nil { - return 0, errFileExists - } - f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) - if err != nil { - return 0, err - } - defer f.Close() - limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - limitWriter := util.NewLimitWriter(f, limiters...) 
- size, err := io.Copy(limitWriter, in) - if err != nil { - os.Remove(file) - return 0, err - } - if err := f.Close(); err != nil { - os.Remove(file) - return 0, err - } - c.mu.Lock() - c.totalSizeCurrent += size - c.mu.Unlock() - return size, nil -} - -func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) { - if !fileIDRegex.MatchString(id) { - return nil, 0, errInvalidFileID - } - file := filepath.Join(c.dir, id) - stat, err := os.Stat(file) - if err != nil { - return nil, 0, err - } - f, err := os.Open(file) - if err != nil { - return nil, 0, err - } - return f, stat.Size(), nil -} - -func (c *fileStore) Remove(ids ...string) error { - for _, id := range ids { - if !fileIDRegex.MatchString(id) { - return errInvalidFileID - } - log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment") - file := filepath.Join(c.dir, id) - if err := os.Remove(file); err != nil { - log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment") - } - } - size, err := dirSize(c.dir) - if err != nil { - return err - } - c.mu.Lock() - c.totalSizeCurrent = size - c.mu.Unlock() - return nil -} - -func (c *fileStore) Size() int64 { - c.mu.Lock() - defer c.mu.Unlock() - return c.totalSizeCurrent -} - -func (c *fileStore) Remaining() int64 { - c.mu.Lock() - defer c.mu.Unlock() - remaining := c.totalSizeLimit - c.totalSizeCurrent - if remaining < 0 { - return 0 - } - return remaining -} - -func dirSize(dir string) (int64, error) { - entries, err := os.ReadDir(dir) - if err != nil { - return 0, err - } - var size int64 - for _, e := range entries { - info, err := e.Info() - if err != nil { - return 0, err - } - size += info.Size() - } - return size, nil -} diff --git a/attachment/store_file_test.go b/attachment/store_file_test.go index 5cfe0db4..ceac09d7 100644 --- a/attachment/store_file_test.go +++ b/attachment/store_file_test.go @@ -7,6 +7,7 @@ import ( "os" "strings" "testing" + "time" "github.com/stretchr/testify/require" "heckel.io/ntfy/v2/util" @@ -17,22 +18,22 @@ var ( ) func TestFileStore_Write_Success(t *testing.T) { - dir, s := newTestFileStore(t) - size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) + dir, c := newTestFileStore(t) + size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) require.Nil(t, err) require.Equal(t, int64(11), size) require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl")) - require.Equal(t, int64(11), s.Size()) - require.Equal(t, int64(10229), s.Remaining()) + require.Equal(t, int64(11), c.Size()) + require.Equal(t, int64(10229), c.Remaining()) } func TestFileStore_Write_Read_Success(t *testing.T) { - _, s := newTestFileStore(t) - size, err := s.Write("abcdefghijkl", strings.NewReader("hello world")) + _, c := newTestFileStore(t) + size, err := c.Write("abcdefghijkl", strings.NewReader("hello world")) require.Nil(t, err) require.Equal(t, int64(11), size) - reader, readSize, err := s.Read("abcdefghijkl") + reader, readSize, err := c.Read("abcdefghijkl") require.Nil(t, err) require.Equal(t, int64(11), readSize) defer reader.Close() @@ -42,57 +43,111 @@ func TestFileStore_Write_Read_Success(t *testing.T) { } func TestFileStore_Write_Remove_Success(t *testing.T) { - dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024) + dir, c := newTestFileStore(t) // max = 10k (10240), each = 1k (1024) for i := 0; i < 10; i++ { // 10x999 = 9990 - size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) 
+ size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) require.Nil(t, err) require.Equal(t, int64(999), size) } - require.Equal(t, int64(9990), s.Size()) - require.Equal(t, int64(250), s.Remaining()) + require.Equal(t, int64(9990), c.Size()) + require.Equal(t, int64(250), c.Remaining()) require.FileExists(t, dir+"/abcdefghijk1") require.FileExists(t, dir+"/abcdefghijk5") - require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5")) + require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5")) require.NoFileExists(t, dir+"/abcdefghijk1") require.NoFileExists(t, dir+"/abcdefghijk5") - require.Equal(t, int64(7992), s.Size()) - require.Equal(t, int64(2248), s.Remaining()) + // Size is not recomputed by Remove; it stays stale until next sync + require.Equal(t, int64(9990), c.Size()) } func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) { - dir, s := newTestFileStore(t) + dir, c := newTestFileStore(t) for i := 0; i < 10; i++ { - size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) + size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) require.Nil(t, err) require.Equal(t, int64(1024), size) } - _, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) + _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abcdefghijkX") } func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) { - dir, s := newTestFileStore(t) - _, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) + dir, c := newTestFileStore(t) + _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abcdefghijkl") } func TestFileStore_Read_NotFound(t *testing.T) { - _, s := newTestFileStore(t) - _, _, err := s.Read("abcdefghijkl") + _, c := newTestFileStore(t) + _, _, err := c.Read("abcdefghijkl") require.Error(t, err) } -func newTestFileStore(t *testing.T) (dir string, store Store) { - dir = t.TempDir() - store, err := NewFileStore(dir, 10*1024) +func TestFileStore_Sync(t *testing.T) { + dir, c := newTestFileStore(t) + + // Write some files + _, err := c.Write("abcdefghijk0", strings.NewReader("file0")) require.Nil(t, err) - return dir, store + _, err = c.Write("abcdefghijk1", strings.NewReader("file1")) + require.Nil(t, err) + _, err = c.Write("abcdefghijk2", strings.NewReader("file2")) + require.Nil(t, err) + + require.Equal(t, int64(15), c.Size()) + + // Set the ID provider to only know about file 0 and 2 + c.localIDs = func() ([]string, error) { + return []string{"abcdefghijk0", "abcdefghijk2"}, nil + } + + // Make file 1's mod time old enough to be cleaned up (> 1 hour) + oldTime := time.Unix(1, 0) + os.Chtimes(dir+"/abcdefghijk1", oldTime, oldTime) + + // Run sync + require.Nil(t, c.sync()) + + // File 1 should be deleted (orphan, old enough) + require.NoFileExists(t, dir+"/abcdefghijk1") + require.FileExists(t, dir+"/abcdefghijk0") + require.FileExists(t, dir+"/abcdefghijk2") + + // Size should be updated + require.Equal(t, int64(10), c.Size()) +} + +func TestFileStore_Sync_SkipsRecentFiles(t *testing.T) { + dir, c := newTestFileStore(t) + + // Write a file + _, err := c.Write("abcdefghijk0", strings.NewReader("file0")) + require.Nil(t, err) + + // Set the ID provider to return empty (no valid IDs) + c.localIDs = func() ([]string, error) { + return []string{}, 
nil + } + + // File was just created, so it should NOT be deleted (< 1 hour old) + require.Nil(t, c.sync()) + require.FileExists(t, dir+"/abcdefghijk0") +} + +func newTestFileStore(t *testing.T) (dir string, cache *Store) { + t.Helper() + dir = t.TempDir() + cache, err := NewFileStore(dir, 10*1024, nil) + require.Nil(t, err) + t.Cleanup(func() { cache.Close() }) + return dir, cache } func readFile(t *testing.T, f string) string { + t.Helper() b, err := os.ReadFile(f) require.Nil(t, err) return string(b) diff --git a/attachment/store_s3.go b/attachment/store_s3.go deleted file mode 100644 index 38f0353a..00000000 --- a/attachment/store_s3.go +++ /dev/null @@ -1,150 +0,0 @@ -package attachment - -import ( - "context" - "fmt" - "io" - "sync" - - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/s3" - "heckel.io/ntfy/v2/util" -) - -const ( - tagS3Store = "s3_store" -) - -type s3Store struct { - client *s3.Client - totalSizeCurrent int64 - totalSizeLimit int64 - mu sync.Mutex -} - -// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format: -// -// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] -func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) { - cfg, err := s3.ParseURL(s3URL) - if err != nil { - return nil, err - } - store := &s3Store{ - client: s3.New(cfg), - totalSizeLimit: totalSizeLimit, - } - if totalSizeLimit > 0 { - size, err := store.computeSize() - if err != nil { - return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err) - } - store.totalSizeCurrent = size - } - return store, nil -} - -func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { - if !fileIDRegex.MatchString(id) { - return 0, errInvalidFileID - } - log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3") - - // Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked - // uploads, so no temp file or Content-Length is needed. - limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - pr, pw := io.Pipe() - lw := util.NewLimitWriter(pw, limiters...) 
- var size int64 - var copyErr error - done := make(chan struct{}) - go func() { - defer close(done) - size, copyErr = io.Copy(lw, in) - if copyErr != nil { - pw.CloseWithError(copyErr) - } else { - pw.Close() - } - }() - putErr := c.client.PutObject(context.Background(), id, pr) - pr.Close() - <-done - if copyErr != nil { - return 0, copyErr - } - if putErr != nil { - return 0, putErr - } - c.mu.Lock() - c.totalSizeCurrent += size - c.mu.Unlock() - return size, nil -} - -func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) { - if !fileIDRegex.MatchString(id) { - return nil, 0, errInvalidFileID - } - return c.client.GetObject(context.Background(), id) -} - -func (c *s3Store) Remove(ids ...string) error { - for _, id := range ids { - if !fileIDRegex.MatchString(id) { - return errInvalidFileID - } - } - // S3 DeleteObjects supports up to 1000 keys per call - for i := 0; i < len(ids); i += 1000 { - end := i + 1000 - if end > len(ids) { - end = len(ids) - } - batch := ids[i:end] - for _, id := range batch { - log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3") - } - if err := c.client.DeleteObjects(context.Background(), batch); err != nil { - return err - } - } - // Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern) - size, err := c.computeSize() - if err != nil { - return fmt.Errorf("s3 store: failed to compute size after remove: %w", err) - } - c.mu.Lock() - c.totalSizeCurrent = size - c.mu.Unlock() - return nil -} - -func (c *s3Store) Size() int64 { - c.mu.Lock() - defer c.mu.Unlock() - return c.totalSizeCurrent -} - -func (c *s3Store) Remaining() int64 { - c.mu.Lock() - defer c.mu.Unlock() - remaining := c.totalSizeLimit - c.totalSizeCurrent - if remaining < 0 { - return 0 - } - return remaining -} - -// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix. 
-func (c *s3Store) computeSize() (int64, error) { - objects, err := c.client.ListAllObjects(context.Background()) - if err != nil { - return 0, err - } - var totalSize int64 - for _, obj := range objects { - totalSize += obj.Size - } - return totalSize, nil -} diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 872e8c23..e29f9ed6 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -10,6 +10,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/stretchr/testify/require" "heckel.io/ntfy/v2/s3" @@ -22,16 +23,16 @@ func TestS3Store_WriteReadRemove(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) // Write - size, err := store.Write("abcdefghijkl", strings.NewReader("hello world")) + size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world")) require.Nil(t, err) require.Equal(t, int64(11), size) - require.Equal(t, int64(11), store.Size()) + require.Equal(t, int64(11), cache.Size()) // Read back - reader, readSize, err := store.Read("abcdefghijkl") + reader, readSize, err := cache.Read("abcdefghijkl") require.Nil(t, err) require.Equal(t, int64(11), readSize) data, err := io.ReadAll(reader) @@ -40,11 +41,11 @@ func TestS3Store_WriteReadRemove(t *testing.T) { require.Equal(t, "hello world", string(data)) // Remove - require.Nil(t, store.Remove("abcdefghijkl")) - require.Equal(t, int64(0), store.Size()) + require.Nil(t, cache.Remove("abcdefghijkl")) + // Size is not recomputed by Remove; stays stale until next sync // Read after remove should fail - _, _, err = store.Read("abcdefghijkl") + _, _, err = cache.Read("abcdefghijkl") require.Error(t, err) } @@ -52,13 +53,13 @@ func TestS3Store_WriteNoPrefix(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "", 10*1024) - size, err := store.Write("abcdefghijkl", strings.NewReader("test")) + size, err := cache.Write("abcdefghijkl", strings.NewReader("test")) require.Nil(t, err) require.Equal(t, int64(4), size) - reader, _, err := store.Read("abcdefghijkl") + reader, _, err := cache.Read("abcdefghijkl") require.Nil(t, err) data, err := io.ReadAll(reader) reader.Close() @@ -70,52 +71,53 @@ func TestS3Store_WriteTotalSizeLimit(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 100) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 100) // First write fits - _, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80))) + _, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80))) require.Nil(t, err) - require.Equal(t, int64(80), store.Size()) - require.Equal(t, int64(20), store.Remaining()) + require.Equal(t, int64(80), cache.Size()) + require.Equal(t, int64(20), cache.Remaining()) // Second write exceeds total limit - _, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50))) - require.Equal(t, util.ErrLimitReached, err) + _, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50))) + require.ErrorIs(t, err, util.ErrLimitReached) } func TestS3Store_WriteFileSizeLimit(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, err := 
store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100)) - require.Equal(t, util.ErrLimitReached, err) + _, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100)) + require.ErrorIs(t, err, util.ErrLimitReached) } func TestS3Store_WriteRemoveMultiple(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) for i := 0; i < 5; i++ { - _, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100))) + _, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100))) require.Nil(t, err) } - require.Equal(t, int64(500), store.Size()) + require.Equal(t, int64(500), cache.Size()) - require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3")) - require.Equal(t, int64(300), store.Size()) + require.Nil(t, cache.Remove("abcdefghijk1", "abcdefghijk3")) + // Size not recomputed by Remove + require.Equal(t, int64(500), cache.Size()) } func TestS3Store_ReadNotFound(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, _, err := store.Read("abcdefghijkl") + _, _, err := cache.Read("abcdefghijkl") require.Error(t, err) } @@ -123,42 +125,93 @@ func TestS3Store_InvalidID(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, err := store.Write("bad", strings.NewReader("x")) + _, err := cache.Write("bad", strings.NewReader("x")) require.Equal(t, errInvalidFileID, err) - _, _, err = store.Read("bad") + _, _, err = cache.Read("bad") require.Equal(t, errInvalidFileID, err) - err = store.Remove("bad") + err = cache.Remove("bad") require.Equal(t, errInvalidFileID, err) } +func TestS3Store_Sync(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + + // Write some files + _, err := cache.Write("abcdefghijk0", strings.NewReader("file0")) + require.Nil(t, err) + _, err = cache.Write("abcdefghijk1", strings.NewReader("file1")) + require.Nil(t, err) + _, err = cache.Write("abcdefghijk2", strings.NewReader("file2")) + require.Nil(t, err) + + require.Equal(t, int64(15), cache.Size()) + + // Set the ID provider to only know about file 0 and 2 + // All mock objects have LastModified set to 2 hours ago, so orphans are eligible for deletion + cache.localIDs = func() ([]string, error) { + return []string{"abcdefghijk0", "abcdefghijk2"}, nil + } + + // Run sync + require.Nil(t, cache.sync()) + + // File 1 should be deleted (orphan) + _, _, err = cache.Read("abcdefghijk1") + require.Error(t, err) + + // Size should be updated + require.Equal(t, int64(10), cache.Size()) +} + +func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) { + mockServer := newMockS3ServerWithModTime(time.Now()) + defer mockServer.Close() + + cache := newTestS3Store(t, mockServer, "my-bucket", "pfx", 10*1024) + + _, err := cache.Write("abcdefghijk0", strings.NewReader("file0")) + require.Nil(t, err) + + // Set the ID provider to return empty (no valid IDs) + cache.localIDs = func() ([]string, error) { + return []string{}, nil + } + + // File was "just created" (mock returns recent time), so it should NOT be deleted + 
require.Nil(t, cache.sync()) + + // File should still exist + reader, _, err := cache.Read("abcdefghijk0") + require.Nil(t, err) + reader.Close() +} + // --- Helpers --- -func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store { +func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store { t.Helper() - // httptest.NewTLSServer URL is like "https://127.0.0.1:PORT" host := strings.TrimPrefix(server.URL, "https://") - s := &s3Store{ - client: &s3.Client{ - AccessKey: "AKID", - SecretKey: "SECRET", - Region: "us-east-1", - Endpoint: host, - Bucket: bucket, - Prefix: prefix, - PathStyle: true, - HTTPClient: server.Client(), - }, - totalSizeLimit: totalSizeLimit, - } - // Compute initial size (should be 0 for fresh mock) - size, err := s.computeSize() + backend := newS3Backend(&s3.Client{ + AccessKey: "AKID", + SecretKey: "SECRET", + Region: "us-east-1", + Endpoint: host, + Bucket: bucket, + Prefix: prefix, + PathStyle: true, + HTTPClient: server.Client(), + }) + cache, err := newStore(backend, totalSizeLimit, nil) require.Nil(t, err) - s.totalSizeCurrent = size - return s + t.Cleanup(func() { cache.Close() }) + return cache } // --- Mock S3 server --- @@ -167,16 +220,22 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body - uploads map[string]map[int][]byte // uploadID -> partNumber -> data - nextID int // counter for generating upload IDs - mu sync.RWMutex + objects map[string][]byte // full key (bucket/key) -> body + uploads map[string]map[int][]byte // uploadID -> partNumber -> data + nextID int // counter for generating upload IDs + lastModTime time.Time // time to return for LastModified in list responses + mu sync.RWMutex } func newMockS3Server() *httptest.Server { + return newMockS3ServerWithModTime(time.Now().Add(-2 * time.Hour)) +} + +func newMockS3ServerWithModTime(modTime time.Time) *httptest.Server { m := &mockS3Server{ - objects: make(map[string][]byte), - uploads: make(map[string]map[int][]byte), + objects: make(map[string][]byte), + uploads: make(map[string]map[int][]byte), + lastModTime: modTime, } return httptest.NewTLSServer(m) } @@ -341,7 +400,11 @@ func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucket continue // different bucket } if prefix == "" || strings.HasPrefix(objKey, prefix) { - contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))}) + contents = append(contents, s3ListObject{ + Key: objKey, + Size: int64(len(body)), + LastModified: m.lastModTime.Format(time.RFC3339), + }) } } m.mu.RUnlock() @@ -362,6 +425,7 @@ type s3ListResponse struct { } type s3ListObject struct { - Key string `xml:"Key"` - Size int64 `xml:"Size"` + Key string `xml:"Key"` + Size int64 `xml:"Size"` + LastModified string `xml:"LastModified"` } diff --git a/message/cache.go b/message/cache.go index 76aba4be..dd4ef0a4 100644 --- a/message/cache.go +++ b/message/cache.go @@ -46,6 +46,7 @@ type queries struct { selectStats string updateStats string updateMessageTime string + selectAttachmentIDs string } // Cache stores published messages @@ -252,18 +253,7 @@ func (c *Cache) MessagesExpired() ([]string, error) { return nil, err } defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err 
!= nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return ids, nil + return readStrings(rows) } // Message returns the message with the given ID, or ErrMessageNotFound if not found @@ -319,18 +309,7 @@ func (c *Cache) Topics() ([]string, error) { return nil, err } defer rows.Close() - topics := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - topics = append(topics, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return topics, nil + return readStrings(rows) } // DeleteMessages deletes the messages with the given IDs @@ -358,15 +337,8 @@ func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, return nil, err } defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { + ids, err := readStrings(rows) + if err != nil { return nil, err } rows.Close() // Close rows before executing delete in same transaction @@ -391,6 +363,16 @@ func (c *Cache) ExpireMessages(topics ...string) error { }) } +// AttachmentIDs returns message IDs with active (non-expired, non-deleted) attachments +func (c *Cache) AttachmentIDs() ([]string, error) { + rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentIDs, time.Now().Unix()) + if err != nil { + return nil, err + } + defer rows.Close() + return readStrings(rows) +} + // AttachmentsExpired returns message IDs with expired attachments that have not been deleted func (c *Cache) AttachmentsExpired() ([]string, error) { rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix()) @@ -398,18 +380,7 @@ func (c *Cache) AttachmentsExpired() ([]string, error) { return nil, err } defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return ids, nil + return readStrings(rows) } // MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted @@ -590,3 +561,18 @@ func readMessage(rows *sql.Rows) (*model.Message, error) { Encoding: encoding, }, nil } + +func readStrings(rows *sql.Rows) ([]string, error) { + strs := make([]string, 0) + for rows.Next() { + var s string + if err := rows.Scan(&s); err != nil { + return nil, err + } + strs = append(strs, s) + } + if err := rows.Err(); err != nil { + return nil, err + } + return strs, nil +} diff --git a/message/cache_postgres.go b/message/cache_postgres.go index ba162da2..d59b2590 100644 --- a/message/cache_postgres.go +++ b/message/cache_postgres.go @@ -74,6 +74,8 @@ const ( postgresSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'` postgresUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'` postgresUpdateMessageTimeQuery = `UPDATE message SET time = $1 WHERE mid = $2` + + postgresSelectAttachmentIDsQuery = `SELECT mid FROM message WHERE attachment_expires > $1 AND attachment_deleted = FALSE` ) var postgresQueries = queries{ @@ -100,6 +102,7 @@ var postgresQueries = queries{ selectStats: postgresSelectStatsQuery, updateStats: postgresUpdateStatsQuery, updateMessageTime: postgresUpdateMessageTimeQuery, + selectAttachmentIDs: postgresSelectAttachmentIDsQuery, } // NewPostgresStore creates a new PostgreSQL-backed message cache 
store using an existing database connection pool. diff --git a/message/cache_sqlite.go b/message/cache_sqlite.go index a36aba0e..6126f1e1 100644 --- a/message/cache_sqlite.go +++ b/message/cache_sqlite.go @@ -77,6 +77,8 @@ const ( sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'` sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'` sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?` + + sqliteSelectAttachmentIDsQuery = `SELECT mid FROM messages WHERE attachment_expires > ? AND attachment_deleted = 0` ) var sqliteQueries = queries{ @@ -103,6 +105,7 @@ var sqliteQueries = queries{ selectStats: sqliteSelectStatsQuery, updateStats: sqliteUpdateStatsQuery, updateMessageTime: sqliteUpdateMessageTimeQuery, + selectAttachmentIDs: sqliteSelectAttachmentIDsQuery, } // NewSQLiteStore creates a SQLite file-backed cache diff --git a/s3/client.go b/s3/client.go index 56f9608e..5ec8caf6 100644 --- a/s3/client.go +++ b/s3/client.go @@ -196,7 +196,15 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK } objects := make([]Object, len(result.Contents)) for i, obj := range result.Contents { - objects[i] = Object(obj) + var lastModified time.Time + if obj.LastModified != "" { + lastModified, _ = time.Parse(time.RFC3339, obj.LastModified) + } + objects[i] = Object{ + Key: obj.Key, + Size: obj.Size, + LastModified: lastModified, + } } return &ListResult{ Objects: objects, diff --git a/s3/client_test.go b/s3/client_test.go index c3a8fe2c..8007601c 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -13,6 +13,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -243,7 +244,7 @@ func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucket var contents []listObject for _, objKey := range allKeys { body := m.objects[bucketPath+"/"+objKey] - contents = append(contents, listObject{Key: objKey, Size: int64(len(body))}) + contents = append(contents, listObject{Key: objKey, Size: int64(len(body)), LastModified: time.Now().Format(time.RFC3339)}) } m.mu.RUnlock() diff --git a/s3/types.go b/s3/types.go index 201c570b..65615fcd 100644 --- a/s3/types.go +++ b/s3/types.go @@ -1,6 +1,9 @@ package s3 -import "fmt" +import ( + "fmt" + "time" +) // Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string. type Config struct { @@ -15,8 +18,9 @@ type Config struct { // Object represents an S3 object returned by list operations. type Object struct { - Key string - Size int64 + Key string + Size int64 + LastModified time.Time } // ListResult holds the response from a ListObjectsV2 call. @@ -49,8 +53,9 @@ type listObjectsV2Response struct { } type listObject struct { - Key string `xml:"Key"` - Size int64 `xml:"Size"` + Key string `xml:"Key"` + Size int64 `xml:"Size"` + LastModified string `xml:"LastModified"` } // deleteResult is the XML response from S3 DeleteObjects diff --git a/server/server.go b/server/server.go index 0972f00d..b43b71ef 100644 --- a/server/server.go +++ b/server/server.go @@ -65,7 +65,7 @@ type Server struct { userManager *user.Manager // Might be nil! 
messageCache *message.Cache // Database that stores the messages webPush *webpush.Store // Database that stores web push subscriptions - fileCache attachment.Store // Attachment store (file system or S3) + fileCache *attachment.Store // Attachment store (file system or S3) stripe stripeAPI // Stripe API, can be replaced with a mock priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set @@ -229,7 +229,7 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } - fileCache, err := createAttachmentStore(conf) + fileCache, err := createAttachmentStore(conf, messageCache) if err != nil { return nil, err } @@ -300,11 +300,14 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) { return message.NewMemStore() } -func createAttachmentStore(conf *Config) (attachment.Store, error) { +func createAttachmentStore(conf *Config, messageCache *message.Cache) (*attachment.Store, error) { + idProvider := func() ([]string, error) { + return messageCache.AttachmentIDs() + } if conf.AttachmentS3URL != "" { - return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit) + return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit, idProvider) } else if conf.AttachmentCacheDir != "" { - return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit) + return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, idProvider) } return nil, nil } @@ -429,6 +432,9 @@ func (s *Server) Stop() { if s.smtpServer != nil { s.smtpServer.Close() } + if s.fileCache != nil { + s.fileCache.Close() + } s.closeDatabases() close(s.closeChan) } diff --git a/util/limit.go b/util/limit.go index ad2118c7..9c39d3dc 100644 --- a/util/limit.go +++ b/util/limit.go @@ -152,6 +152,61 @@ func (l *RateLimiter) Reset() { l.value = 0 } +// CountingReader wraps an io.Reader and counts the number of bytes read through it. +type CountingReader struct { + r io.Reader + total int64 +} + +// NewCountingReader creates a new CountingReader +func NewCountingReader(r io.Reader) *CountingReader { + return &CountingReader{r: r} +} + +// Read passes through to the underlying reader and counts the bytes read +func (r *CountingReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + r.total += int64(n) + return +} + +// Total returns the total number of bytes read so far +func (r *CountingReader) Total() int64 { + return r.total +} + +// LimitReader implements an io.Reader that will pass through all Read calls to the underlying +// reader r until any of the limiter's limit is reached, at which point a Read will return ErrLimitReached. +// Each limiter's value is increased after every read based on the number of bytes actually read. 
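+//
+// A usage sketch (reader and limit are illustrative):
+//
+//	lr := NewLimitReader(upload, NewFixedLimiter(10*1024))
+//	if _, err := io.Copy(io.Discard, lr); err == ErrLimitReached {
+//		// the stream exceeded the 10 KiB budget
+//	}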
+type LimitReader struct {
+	r        io.Reader
+	limiters []Limiter
+}
+
+// NewLimitReader creates a new LimitReader
+func NewLimitReader(r io.Reader, limiters ...Limiter) *LimitReader {
+	return &LimitReader{
+		r:        r,
+		limiters: limiters,
+	}
+}
+
+// Read passes through all reads to the underlying reader until any limiter's limit is reached
+func (r *LimitReader) Read(p []byte) (n int, err error) {
+	n, err = r.r.Read(p)
+	if n > 0 {
+		for i := 0; i < len(r.limiters); i++ {
+			if !r.limiters[i].AllowN(int64(n)) {
+				for j := i - 1; j >= 0; j-- {
+					r.limiters[j].AllowN(-int64(n)) // Roll back limiters that already accepted these bytes
+				}
+				return 0, ErrLimitReached
+			}
+		}
+	}
+	return
+}
+
 // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
 // writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
 // Each limiter's value is increased with every write.