This commit is contained in:
binwiederhier 2026-03-17 20:53:41 -04:00
parent cffa57950a
commit ef314960d0
16 changed files with 682 additions and 436 deletions

22
attachment/backend.go Normal file
View file

@ -0,0 +1,22 @@
package attachment
import (
"io"
"time"
)
// object represents a single stored attachment object as reported by a
// backend's List call: its ID (the file/object name), its size in bytes,
// and its last-modified timestamp (used by sync's orphan grace period).
type object struct {
	ID           string
	Size         int64
	LastModified time.Time
}
// backend is a minimal I/O interface for storing and retrieving attachment
// files. It has no knowledge of size tracking, limiting, or ID validation;
// those concerns live one level up, in Store.
type backend interface {
	// Put stores the contents of in under the given id.
	Put(id string, in io.Reader) error
	// Get returns a reader for the object and its size in bytes;
	// the caller must close the reader.
	Get(id string) (io.ReadCloser, int64, error)
	// Delete removes the given objects.
	Delete(ids ...string) error
	// List enumerates all stored objects with ID, size, and mod time.
	List() ([]object, error)
}

View file

@ -0,0 +1,85 @@
package attachment
import (
"io"
"os"
"path/filepath"
"heckel.io/ntfy/v2/log"
)
// tagFileBackend is the log tag used for all file backend log entries.
const tagFileBackend = "file_backend"

// fileBackend stores attachment files as plain files in a single directory,
// one file per attachment ID.
type fileBackend struct {
	dir string // root directory holding the attachment files
}

// Compile-time check that fileBackend satisfies the backend interface.
var _ backend = (*fileBackend)(nil)
// newFileBackend creates a fileBackend rooted at dir, creating the
// directory (mode 0700, owner-only) if it does not yet exist.
func newFileBackend(dir string) (*fileBackend, error) {
	if err := os.MkdirAll(dir, 0700); err != nil {
		return nil, err
	}
	b := &fileBackend{dir: dir}
	return b, nil
}
// Put writes the contents of in to a file named id inside the backend
// directory, created with mode 0600 and truncated if it already exists.
// On any error the partially written file is removed so no corrupt
// attachment is left behind.
func (b *fileBackend) Put(id string, in io.Reader) error {
	file := filepath.Join(b.dir, id)
	f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	_, err = io.Copy(f, in)
	// Close exactly once, and before any removal: the original defer-based
	// variant closed the file twice on success and attempted os.Remove
	// while the handle was still open, which fails on Windows.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		os.Remove(file) //nolint:errcheck
		return err
	}
	return nil
}
// Get opens the file for the given ID and returns a reader plus its size.
// The size is taken from the opened file handle rather than a separate
// os.Stat call, so it cannot race with a concurrent replace or delete
// between stat and open (the original stat-then-open was a TOCTOU window).
// The caller must close the returned reader.
func (b *fileBackend) Get(id string) (io.ReadCloser, int64, error) {
	f, err := os.Open(filepath.Join(b.dir, id))
	if err != nil {
		return nil, 0, err
	}
	stat, err := f.Stat()
	if err != nil {
		f.Close() //nolint:errcheck
		return nil, 0, err
	}
	return f, stat.Size(), nil
}
// Delete removes the files for the given IDs from the backend directory.
// Removal failures are logged at debug level and otherwise ignored, so one
// missing file does not abort deletion of the remaining IDs; it always
// returns nil.
func (b *fileBackend) Delete(ids ...string) error {
	for _, id := range ids {
		if err := os.Remove(filepath.Join(b.dir, id)); err != nil {
			log.Tag(tagFileBackend).Field("message_id", id).Err(err).Debug("Error deleting attachment")
		}
	}
	return nil
}
// List returns all objects currently stored in the backend directory.
// Stray subdirectories are skipped (attachments are plain files), and an
// entry that disappears between ReadDir and Info — e.g. deleted by a
// concurrent sync — is silently ignored instead of failing the whole
// listing, as the original did.
func (b *fileBackend) List() ([]object, error) {
	entries, err := os.ReadDir(b.dir)
	if err != nil {
		return nil, err
	}
	objects := make([]object, 0, len(entries))
	for _, e := range entries {
		if e.IsDir() {
			continue // attachments are plain files; ignore stray directories
		}
		info, err := e.Info()
		if err != nil {
			if os.IsNotExist(err) {
				continue // entry vanished concurrently; not an error
			}
			return nil, err
		}
		objects = append(objects, object{
			ID:           e.Name(),
			Size:         info.Size(),
			LastModified: info.ModTime(),
		})
	}
	return objects, nil
}

70
attachment/backend_s3.go Normal file
View file

@ -0,0 +1,70 @@
package attachment
import (
"context"
"io"
"strings"
"heckel.io/ntfy/v2/s3"
"heckel.io/ntfy/v2/log"
)
// tagS3Backend is the log tag used for all S3 backend log entries.
const tagS3Backend = "s3_backend"

// s3Backend stores attachment files as objects in an S3 bucket via the
// project's s3.Client wrapper.
type s3Backend struct {
	client *s3.Client
}

// Compile-time check that s3Backend satisfies the backend interface.
var _ backend = (*s3Backend)(nil)
// newS3Backend creates an s3Backend wrapping the given client.
func newS3Backend(client *s3.Client) *s3Backend {
	return &s3Backend{client: client}
}
// Put uploads the contents of in to the object named id.
func (b *s3Backend) Put(id string, in io.Reader) error {
	return b.client.PutObject(context.Background(), id, in)
}
// Get downloads the object named id, returning a reader and the object's
// size; the caller must close the reader.
func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) {
	return b.client.GetObject(context.Background(), id)
}
// Delete removes the given IDs from the bucket, splitting the work into
// batches of at most 1000 keys — the maximum the S3 DeleteObjects API
// accepts per call. The first failing batch aborts and returns the error.
func (b *s3Backend) Delete(ids ...string) error {
	const maxBatchSize = 1000 // S3 DeleteObjects supports up to 1000 keys per call
	for len(ids) > 0 {
		n := len(ids)
		if n > maxBatchSize {
			n = maxBatchSize
		}
		batch, rest := ids[:n], ids[n:]
		for _, id := range batch {
			log.Tag(tagS3Backend).Field("message_id", id).Debug("Deleting attachment from S3")
		}
		if err := b.client.DeleteObjects(context.Background(), batch); err != nil {
			return err
		}
		ids = rest
	}
	return nil
}
// List returns all objects stored under the client's configured prefix.
// The prefix (plus the "/" separator) is stripped from each key so the
// returned IDs match the IDs used with Put/Get/Delete.
func (b *s3Backend) List() ([]object, error) {
	remote, err := b.client.ListAllObjects(context.Background())
	if err != nil {
		return nil, err
	}
	result := make([]object, 0, len(remote))
	for _, obj := range remote {
		id := obj.Key
		if p := b.client.Prefix; p != "" {
			id = strings.TrimPrefix(id, p+"/")
		}
		result = append(result, object{
			ID:           id,
			Size:         obj.Size,
			LastModified: obj.LastModified,
		})
	}
	return result, nil
}

View file

@ -5,21 +5,193 @@ import (
"fmt"
"io"
"regexp"
"sync"
"time"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/s3"
"heckel.io/ntfy/v2/util"
)
const (
	tagStore          = "attachment_cache" // log tag for all Store log entries
	syncInterval      = 15 * time.Minute   // How often to run the background sync loop
	orphanGracePeriod = time.Hour          // Don't delete orphaned objects younger than this to avoid races with in-flight uploads
)

var (
	// fileIDRegex matches valid attachment file IDs: exactly
	// model.MessageIDLength characters from [-_A-Za-z0-9].
	fileIDRegex      = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
	errInvalidFileID = errors.New("invalid file ID")
)
// Store is an interface for storing and retrieving attachment files
type Store interface {
Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error)
Read(id string) (io.ReadCloser, int64, error)
Remove(ids ...string) error
Size() int64
Remaining() int64
// Store manages attachment storage with shared logic for size tracking,
// limiting, ID validation, and background sync to reconcile storage with
// the database. The actual I/O is delegated to a backend (file or S3).
type Store struct {
	backend          backend                  // underlying object storage (file or S3)
	totalSizeCurrent int64                    // current total size in bytes; updated by Write and sync
	totalSizeLimit   int64                    // maximum total size in bytes across all attachments
	localIDs         func() ([]string, error) // returns IDs that should exist; nil disables the background sync
	closeChan        chan struct{}            // closed by Close to stop the sync goroutine
	mu               sync.Mutex               // Protects totalSizeCurrent
}
// NewFileStore creates a new file-system backed attachment cache rooted at
// dir, creating the directory if needed. If localIDsFn is non-nil, a
// background sync goroutine reconciles the directory against the IDs it
// returns.
func NewFileStore(dir string, totalSizeLimit int64, localIDsFn func() ([]string, error)) (*Store, error) {
	b, err := newFileBackend(dir)
	if err != nil {
		return nil, err
	}
	return newStore(b, totalSizeLimit, localIDsFn)
}
// NewS3Store creates a new S3-backed attachment cache. The s3URL must be in the format:
//
//	s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
//
// If localIDs is non-nil, a background sync goroutine reconciles the bucket
// contents against the IDs it returns.
func NewS3Store(s3URL string, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) {
	cfg, err := s3.ParseURL(s3URL)
	if err != nil {
		return nil, err
	}
	b := newS3Backend(s3.New(cfg))
	return newStore(b, totalSizeLimit, localIDs)
}
// newStore wires a backend into a Store and, when a localIDs provider is
// given, launches the background sync goroutine (stopped via Close).
func newStore(b backend, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) {
	s := &Store{
		backend:        b,
		totalSizeLimit: totalSizeLimit,
		localIDs:       localIDs,
		closeChan:      make(chan struct{}),
	}
	if s.localIDs != nil {
		go s.syncLoop()
	}
	return s, nil
}
// Write stores an attachment file. The id is validated against fileIDRegex,
// and the write is subject to the remaining total-size capacity as well as
// any additional limiters the caller provides. On failure, any partially
// stored object is deleted. Returns the number of bytes consumed from in.
func (c *Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
	if !fileIDRegex.MatchString(id) {
		return 0, errInvalidFileID
	}
	log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment")
	// Count bytes as they stream through, enforcing all limiters plus the
	// store-wide remaining-capacity limit.
	counter := util.NewCountingReader(in)
	limited := util.NewLimitReader(counter, append(limiters, util.NewFixedLimiter(c.Remaining()))...)
	if err := c.backend.Put(id, limited); err != nil {
		c.backend.Delete(id) //nolint:errcheck
		return 0, err
	}
	written := counter.Total()
	c.mu.Lock()
	c.totalSizeCurrent += written
	c.mu.Unlock()
	return written, nil
}
// Read retrieves an attachment file by ID, returning a reader and the
// object size. The caller must close the returned reader.
func (c *Store) Read(id string) (io.ReadCloser, int64, error) {
	if fileIDRegex.MatchString(id) {
		return c.backend.Get(id)
	}
	return nil, 0, errInvalidFileID
}
// Remove deletes attachment files by ID. All IDs are validated up front, so
// nothing is deleted if any ID is invalid. It does NOT recompute the total
// size; the next sync() call will correct it.
func (c *Store) Remove(ids ...string) error {
	for i := range ids {
		if !fileIDRegex.MatchString(ids[i]) {
			return errInvalidFileID
		}
	}
	return c.backend.Delete(ids...)
}
// sync reconciles the backend storage with the database. It lists all remote
// objects, deletes orphans (objects whose ID is not in the localIDs set and
// whose LastModified is older than orphanGracePeriod), and recomputes
// totalSizeCurrent from the remaining objects.
//
// Only started when localIDs is non-nil (see newStore), so the provider is
// not nil-checked here.
func (c *Store) sync() error {
	log.Tag(tagStore).Debug("Sync: starting sync loop")
	// Build the set of IDs the database says should exist.
	localIDs, err := c.localIDs()
	if err != nil {
		return fmt.Errorf("attachment sync: failed to get valid IDs: %w", err)
	}
	localIDMap := make(map[string]struct{}, len(localIDs))
	for _, id := range localIDs {
		localIDMap[id] = struct{}{}
	}
	remoteObjects, err := c.backend.List()
	if err != nil {
		return fmt.Errorf("attachment sync: failed to list objects: %w", err)
	}
	// Calculate total cache size and collect orphaned attachments, excluding objects younger
	// than the grace period to account for races, and skipping objects with invalid IDs.
	cutoff := time.Now().Add(-orphanGracePeriod)
	var orphanIDs []string
	var totalSize int64
	for _, obj := range remoteObjects {
		if !fileIDRegex.MatchString(obj.ID) {
			continue // foreign object in the storage location; ignore entirely
		}
		if _, ok := localIDMap[obj.ID]; !ok && obj.LastModified.Before(cutoff) {
			orphanIDs = append(orphanIDs, obj.ID)
		} else {
			// Known to the DB, or too young to safely delete: count it.
			totalSize += obj.Size
		}
	}
	log.Tag(tagStore).Debug("Sync: cache size updated to %s", util.FormatSizeHuman(totalSize))
	c.mu.Lock()
	c.totalSizeCurrent = totalSize
	c.mu.Unlock()
	// Delete orphaned attachments
	if len(orphanIDs) > 0 {
		log.Tag(tagStore).Debug("Sync: deleting %d orphaned attachment(s)", len(orphanIDs))
		if err := c.backend.Delete(orphanIDs...); err != nil {
			return fmt.Errorf("attachment sync: failed to delete orphaned objects: %w", err)
		}
	}
	return nil
}
// Size returns the total size of all attachments, as of the last write or
// sync run.
func (c *Store) Size() int64 {
	c.mu.Lock()
	size := c.totalSizeCurrent
	c.mu.Unlock()
	return size
}
// Remaining returns how many more bytes may be written before the total
// size limit is reached; never negative.
func (c *Store) Remaining() int64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	if left := c.totalSizeLimit - c.totalSizeCurrent; left > 0 {
		return left
	}
	return 0
}
// Close stops the background sync goroutine by closing closeChan.
// NOTE(review): not idempotent — a second Close panics on the double
// channel close; confirm callers only close once.
func (c *Store) Close() {
	close(c.closeChan)
}
// syncLoop runs one sync immediately, then re-syncs every syncInterval
// until Close is called. Sync failures are logged as warnings and the loop
// keeps going.
func (c *Store) syncLoop() {
	runSync := func() {
		if err := c.sync(); err != nil {
			log.Tag(tagStore).Err(err).Warn("Attachment sync failed")
		}
	}
	runSync()
	ticker := time.NewTicker(syncInterval)
	defer ticker.Stop()
	for {
		select {
		case <-c.closeChan:
			return
		case <-ticker.C:
			runSync()
		}
	}
}

View file

@ -1,139 +0,0 @@
package attachment
import (
"errors"
"io"
"os"
"path/filepath"
"sync"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/util"
)
const tagFileStore = "file_store"
var errFileExists = errors.New("file exists")
type fileStore struct {
dir string
totalSizeCurrent int64
totalSizeLimit int64
mu sync.Mutex
}
// NewFileStore creates a new file-system backed attachment store
func NewFileStore(dir string, totalSizeLimit int64) (Store, error) {
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, err
}
size, err := dirSize(dir)
if err != nil {
return nil, err
}
return &fileStore{
dir: dir,
totalSizeCurrent: size,
totalSizeLimit: totalSizeLimit,
}, nil
}
func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment")
file := filepath.Join(c.dir, id)
if _, err := os.Stat(file); err == nil {
return 0, errFileExists
}
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return 0, err
}
defer f.Close()
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
limitWriter := util.NewLimitWriter(f, limiters...)
size, err := io.Copy(limitWriter, in)
if err != nil {
os.Remove(file)
return 0, err
}
if err := f.Close(); err != nil {
os.Remove(file)
return 0, err
}
c.mu.Lock()
c.totalSizeCurrent += size
c.mu.Unlock()
return size, nil
}
func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) {
if !fileIDRegex.MatchString(id) {
return nil, 0, errInvalidFileID
}
file := filepath.Join(c.dir, id)
stat, err := os.Stat(file)
if err != nil {
return nil, 0, err
}
f, err := os.Open(file)
if err != nil {
return nil, 0, err
}
return f, stat.Size(), nil
}
func (c *fileStore) Remove(ids ...string) error {
for _, id := range ids {
if !fileIDRegex.MatchString(id) {
return errInvalidFileID
}
log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment")
file := filepath.Join(c.dir, id)
if err := os.Remove(file); err != nil {
log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment")
}
}
size, err := dirSize(c.dir)
if err != nil {
return err
}
c.mu.Lock()
c.totalSizeCurrent = size
c.mu.Unlock()
return nil
}
func (c *fileStore) Size() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.totalSizeCurrent
}
func (c *fileStore) Remaining() int64 {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.totalSizeLimit - c.totalSizeCurrent
if remaining < 0 {
return 0
}
return remaining
}
func dirSize(dir string) (int64, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return 0, err
}
var size int64
for _, e := range entries {
info, err := e.Info()
if err != nil {
return 0, err
}
size += info.Size()
}
return size, nil
}

View file

@ -7,6 +7,7 @@ import (
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/util"
@ -17,22 +18,22 @@ var (
)
func TestFileStore_Write_Success(t *testing.T) {
dir, s := newTestFileStore(t)
size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
dir, c := newTestFileStore(t)
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
require.Equal(t, int64(11), s.Size())
require.Equal(t, int64(10229), s.Remaining())
require.Equal(t, int64(11), c.Size())
require.Equal(t, int64(10229), c.Remaining())
}
func TestFileStore_Write_Read_Success(t *testing.T) {
_, s := newTestFileStore(t)
size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"))
_, c := newTestFileStore(t)
size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"))
require.Nil(t, err)
require.Equal(t, int64(11), size)
reader, readSize, err := s.Read("abcdefghijkl")
reader, readSize, err := c.Read("abcdefghijkl")
require.Nil(t, err)
require.Equal(t, int64(11), readSize)
defer reader.Close()
@ -42,57 +43,111 @@ func TestFileStore_Write_Read_Success(t *testing.T) {
}
func TestFileStore_Write_Remove_Success(t *testing.T) {
dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
dir, c := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
for i := 0; i < 10; i++ { // 10x999 = 9990
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
require.Nil(t, err)
require.Equal(t, int64(999), size)
}
require.Equal(t, int64(9990), s.Size())
require.Equal(t, int64(250), s.Remaining())
require.Equal(t, int64(9990), c.Size())
require.Equal(t, int64(250), c.Remaining())
require.FileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abcdefghijk5")
require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5"))
require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
require.NoFileExists(t, dir+"/abcdefghijk1")
require.NoFileExists(t, dir+"/abcdefghijk5")
require.Equal(t, int64(7992), s.Size())
require.Equal(t, int64(2248), s.Remaining())
// Size is not recomputed by Remove; it stays stale until next sync
require.Equal(t, int64(9990), c.Size())
}
func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) {
dir, s := newTestFileStore(t)
dir, c := newTestFileStore(t)
for i := 0; i < 10; i++ {
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
require.Nil(t, err)
require.Equal(t, int64(1024), size)
}
_, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkX")
}
func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) {
dir, s := newTestFileStore(t)
_, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
dir, c := newTestFileStore(t)
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
require.Equal(t, util.ErrLimitReached, err)
require.NoFileExists(t, dir+"/abcdefghijkl")
}
func TestFileStore_Read_NotFound(t *testing.T) {
_, s := newTestFileStore(t)
_, _, err := s.Read("abcdefghijkl")
_, c := newTestFileStore(t)
_, _, err := c.Read("abcdefghijkl")
require.Error(t, err)
}
func newTestFileStore(t *testing.T) (dir string, store Store) {
dir = t.TempDir()
store, err := NewFileStore(dir, 10*1024)
func TestFileStore_Sync(t *testing.T) {
dir, c := newTestFileStore(t)
// Write some files
_, err := c.Write("abcdefghijk0", strings.NewReader("file0"))
require.Nil(t, err)
return dir, store
_, err = c.Write("abcdefghijk1", strings.NewReader("file1"))
require.Nil(t, err)
_, err = c.Write("abcdefghijk2", strings.NewReader("file2"))
require.Nil(t, err)
require.Equal(t, int64(15), c.Size())
// Set the ID provider to only know about file 0 and 2
c.localIDs = func() ([]string, error) {
return []string{"abcdefghijk0", "abcdefghijk2"}, nil
}
// Make file 1's mod time old enough to be cleaned up (> 1 hour)
oldTime := time.Unix(1, 0)
os.Chtimes(dir+"/abcdefghijk1", oldTime, oldTime)
// Run sync
require.Nil(t, c.sync())
// File 1 should be deleted (orphan, old enough)
require.NoFileExists(t, dir+"/abcdefghijk1")
require.FileExists(t, dir+"/abcdefghijk0")
require.FileExists(t, dir+"/abcdefghijk2")
// Size should be updated
require.Equal(t, int64(10), c.Size())
}
func TestFileStore_Sync_SkipsRecentFiles(t *testing.T) {
dir, c := newTestFileStore(t)
// Write a file
_, err := c.Write("abcdefghijk0", strings.NewReader("file0"))
require.Nil(t, err)
// Set the ID provider to return empty (no valid IDs)
c.localIDs = func() ([]string, error) {
return []string{}, nil
}
// File was just created, so it should NOT be deleted (< 1 hour old)
require.Nil(t, c.sync())
require.FileExists(t, dir+"/abcdefghijk0")
}
func newTestFileStore(t *testing.T) (dir string, cache *Store) {
t.Helper()
dir = t.TempDir()
cache, err := NewFileStore(dir, 10*1024, nil)
require.Nil(t, err)
t.Cleanup(func() { cache.Close() })
return dir, cache
}
func readFile(t *testing.T, f string) string {
t.Helper()
b, err := os.ReadFile(f)
require.Nil(t, err)
return string(b)

View file

@ -1,150 +0,0 @@
package attachment
import (
"context"
"fmt"
"io"
"sync"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/s3"
"heckel.io/ntfy/v2/util"
)
const (
tagS3Store = "s3_store"
)
type s3Store struct {
client *s3.Client
totalSizeCurrent int64
totalSizeLimit int64
mu sync.Mutex
}
// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format:
//
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) {
cfg, err := s3.ParseURL(s3URL)
if err != nil {
return nil, err
}
store := &s3Store{
client: s3.New(cfg),
totalSizeLimit: totalSizeLimit,
}
if totalSizeLimit > 0 {
size, err := store.computeSize()
if err != nil {
return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err)
}
store.totalSizeCurrent = size
}
return store, nil
}
func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
if !fileIDRegex.MatchString(id) {
return 0, errInvalidFileID
}
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
// Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
// uploads, so no temp file or Content-Length is needed.
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
pr, pw := io.Pipe()
lw := util.NewLimitWriter(pw, limiters...)
var size int64
var copyErr error
done := make(chan struct{})
go func() {
defer close(done)
size, copyErr = io.Copy(lw, in)
if copyErr != nil {
pw.CloseWithError(copyErr)
} else {
pw.Close()
}
}()
putErr := c.client.PutObject(context.Background(), id, pr)
pr.Close()
<-done
if copyErr != nil {
return 0, copyErr
}
if putErr != nil {
return 0, putErr
}
c.mu.Lock()
c.totalSizeCurrent += size
c.mu.Unlock()
return size, nil
}
func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) {
if !fileIDRegex.MatchString(id) {
return nil, 0, errInvalidFileID
}
return c.client.GetObject(context.Background(), id)
}
func (c *s3Store) Remove(ids ...string) error {
for _, id := range ids {
if !fileIDRegex.MatchString(id) {
return errInvalidFileID
}
}
// S3 DeleteObjects supports up to 1000 keys per call
for i := 0; i < len(ids); i += 1000 {
end := i + 1000
if end > len(ids) {
end = len(ids)
}
batch := ids[i:end]
for _, id := range batch {
log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3")
}
if err := c.client.DeleteObjects(context.Background(), batch); err != nil {
return err
}
}
// Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern)
size, err := c.computeSize()
if err != nil {
return fmt.Errorf("s3 store: failed to compute size after remove: %w", err)
}
c.mu.Lock()
c.totalSizeCurrent = size
c.mu.Unlock()
return nil
}
func (c *s3Store) Size() int64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.totalSizeCurrent
}
func (c *s3Store) Remaining() int64 {
c.mu.Lock()
defer c.mu.Unlock()
remaining := c.totalSizeLimit - c.totalSizeCurrent
if remaining < 0 {
return 0
}
return remaining
}
// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix.
func (c *s3Store) computeSize() (int64, error) {
objects, err := c.client.ListAllObjects(context.Background())
if err != nil {
return 0, err
}
var totalSize int64
for _, obj := range objects {
totalSize += obj.Size
}
return totalSize, nil
}

View file

@ -10,6 +10,7 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"heckel.io/ntfy/v2/s3"
@ -22,16 +23,16 @@ func TestS3Store_WriteReadRemove(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
// Write
size, err := store.Write("abcdefghijkl", strings.NewReader("hello world"))
size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world"))
require.Nil(t, err)
require.Equal(t, int64(11), size)
require.Equal(t, int64(11), store.Size())
require.Equal(t, int64(11), cache.Size())
// Read back
reader, readSize, err := store.Read("abcdefghijkl")
reader, readSize, err := cache.Read("abcdefghijkl")
require.Nil(t, err)
require.Equal(t, int64(11), readSize)
data, err := io.ReadAll(reader)
@ -40,11 +41,11 @@ func TestS3Store_WriteReadRemove(t *testing.T) {
require.Equal(t, "hello world", string(data))
// Remove
require.Nil(t, store.Remove("abcdefghijkl"))
require.Equal(t, int64(0), store.Size())
require.Nil(t, cache.Remove("abcdefghijkl"))
// Size is not recomputed by Remove; stays stale until next sync
// Read after remove should fail
_, _, err = store.Read("abcdefghijkl")
_, _, err = cache.Read("abcdefghijkl")
require.Error(t, err)
}
@ -52,13 +53,13 @@ func TestS3Store_WriteNoPrefix(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "", 10*1024)
cache := newTestS3Store(t, server, "my-bucket", "", 10*1024)
size, err := store.Write("abcdefghijkl", strings.NewReader("test"))
size, err := cache.Write("abcdefghijkl", strings.NewReader("test"))
require.Nil(t, err)
require.Equal(t, int64(4), size)
reader, _, err := store.Read("abcdefghijkl")
reader, _, err := cache.Read("abcdefghijkl")
require.Nil(t, err)
data, err := io.ReadAll(reader)
reader.Close()
@ -70,52 +71,53 @@ func TestS3Store_WriteTotalSizeLimit(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 100)
cache := newTestS3Store(t, server, "my-bucket", "pfx", 100)
// First write fits
_, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
_, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
require.Nil(t, err)
require.Equal(t, int64(80), store.Size())
require.Equal(t, int64(20), store.Remaining())
require.Equal(t, int64(80), cache.Size())
require.Equal(t, int64(20), cache.Remaining())
// Second write exceeds total limit
_, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
require.Equal(t, util.ErrLimitReached, err)
_, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
require.ErrorIs(t, err, util.ErrLimitReached)
}
func TestS3Store_WriteFileSizeLimit(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
require.Equal(t, util.ErrLimitReached, err)
_, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
require.ErrorIs(t, err, util.ErrLimitReached)
}
func TestS3Store_WriteRemoveMultiple(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
for i := 0; i < 5; i++ {
_, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
_, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
require.Nil(t, err)
}
require.Equal(t, int64(500), store.Size())
require.Equal(t, int64(500), cache.Size())
require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3"))
require.Equal(t, int64(300), store.Size())
require.Nil(t, cache.Remove("abcdefghijk1", "abcdefghijk3"))
// Size not recomputed by Remove
require.Equal(t, int64(500), cache.Size())
}
func TestS3Store_ReadNotFound(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, _, err := store.Read("abcdefghijkl")
_, _, err := cache.Read("abcdefghijkl")
require.Error(t, err)
}
@ -123,26 +125,80 @@ func TestS3Store_InvalidID(t *testing.T) {
server := newMockS3Server()
defer server.Close()
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
_, err := store.Write("bad", strings.NewReader("x"))
_, err := cache.Write("bad", strings.NewReader("x"))
require.Equal(t, errInvalidFileID, err)
_, _, err = store.Read("bad")
_, _, err = cache.Read("bad")
require.Equal(t, errInvalidFileID, err)
err = store.Remove("bad")
err = cache.Remove("bad")
require.Equal(t, errInvalidFileID, err)
}
func TestS3Store_Sync(t *testing.T) {
server := newMockS3Server()
defer server.Close()
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
// Write some files
_, err := cache.Write("abcdefghijk0", strings.NewReader("file0"))
require.Nil(t, err)
_, err = cache.Write("abcdefghijk1", strings.NewReader("file1"))
require.Nil(t, err)
_, err = cache.Write("abcdefghijk2", strings.NewReader("file2"))
require.Nil(t, err)
require.Equal(t, int64(15), cache.Size())
// Set the ID provider to only know about file 0 and 2
// All mock objects have LastModified set to 2 hours ago, so orphans are eligible for deletion
cache.localIDs = func() ([]string, error) {
return []string{"abcdefghijk0", "abcdefghijk2"}, nil
}
// Run sync
require.Nil(t, cache.sync())
// File 1 should be deleted (orphan)
_, _, err = cache.Read("abcdefghijk1")
require.Error(t, err)
// Size should be updated
require.Equal(t, int64(10), cache.Size())
}
func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) {
mockServer := newMockS3ServerWithModTime(time.Now())
defer mockServer.Close()
cache := newTestS3Store(t, mockServer, "my-bucket", "pfx", 10*1024)
_, err := cache.Write("abcdefghijk0", strings.NewReader("file0"))
require.Nil(t, err)
// Set the ID provider to return empty (no valid IDs)
cache.localIDs = func() ([]string, error) {
return []string{}, nil
}
// File was "just created" (mock returns recent time), so it should NOT be deleted
require.Nil(t, cache.sync())
// File should still exist
reader, _, err := cache.Read("abcdefghijk0")
require.Nil(t, err)
reader.Close()
}
// --- Helpers ---
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store {
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store {
t.Helper()
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
host := strings.TrimPrefix(server.URL, "https://")
s := &s3Store{
client: &s3.Client{
backend := newS3Backend(&s3.Client{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
@ -151,14 +207,11 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
},
totalSizeLimit: totalSizeLimit,
}
// Compute initial size (should be 0 for fresh mock)
size, err := s.computeSize()
})
cache, err := newStore(backend, totalSizeLimit, nil)
require.Nil(t, err)
s.totalSizeCurrent = size
return s
t.Cleanup(func() { cache.Close() })
return cache
}
// --- Mock S3 server ---
@ -170,13 +223,19 @@ type mockS3Server struct {
objects map[string][]byte // full key (bucket/key) -> body
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
nextID int // counter for generating upload IDs
lastModTime time.Time // time to return for LastModified in list responses
mu sync.RWMutex
}
func newMockS3Server() *httptest.Server {
return newMockS3ServerWithModTime(time.Now().Add(-2 * time.Hour))
}
func newMockS3ServerWithModTime(modTime time.Time) *httptest.Server {
m := &mockS3Server{
objects: make(map[string][]byte),
uploads: make(map[string]map[int][]byte),
lastModTime: modTime,
}
return httptest.NewTLSServer(m)
}
@ -341,7 +400,11 @@ func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucket
continue // different bucket
}
if prefix == "" || strings.HasPrefix(objKey, prefix) {
contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))})
contents = append(contents, s3ListObject{
Key: objKey,
Size: int64(len(body)),
LastModified: m.lastModTime.Format(time.RFC3339),
})
}
}
m.mu.RUnlock()
@ -364,4 +427,5 @@ type s3ListResponse struct {
type s3ListObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
LastModified string `xml:"LastModified"`
}

View file

@ -46,6 +46,7 @@ type queries struct {
selectStats string
updateStats string
updateMessageTime string
selectAttachmentIDs string
}
// Cache stores published messages
@ -252,18 +253,7 @@ func (c *Cache) MessagesExpired() ([]string, error) {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
return readStrings(rows)
}
// Message returns the message with the given ID, or ErrMessageNotFound if not found
@ -319,18 +309,7 @@ func (c *Cache) Topics() ([]string, error) {
return nil, err
}
defer rows.Close()
topics := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
topics = append(topics, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return topics, nil
return readStrings(rows)
}
// DeleteMessages deletes the messages with the given IDs
@ -358,15 +337,8 @@ func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string,
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
ids, err := readStrings(rows)
if err != nil {
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
@ -391,6 +363,16 @@ func (c *Cache) ExpireMessages(topics ...string) error {
})
}
// AttachmentIDs returns message IDs with active (non-expired, non-deleted)
// attachments. It queries the read-only connection, binding the current
// Unix time as the expiry cutoff.
func (c *Cache) AttachmentIDs() ([]string, error) {
	rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentIDs, time.Now().Unix())
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	return readStrings(rows)
}
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
func (c *Cache) AttachmentsExpired() ([]string, error) {
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
@ -398,18 +380,7 @@ func (c *Cache) AttachmentsExpired() ([]string, error) {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, err
}
return ids, nil
return readStrings(rows)
}
// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted
@ -590,3 +561,18 @@ func readMessage(rows *sql.Rows) (*model.Message, error) {
Encoding: encoding,
}, nil
}
// readStrings drains a single-column result set into a slice of strings.
// It always returns a non-nil (possibly empty) slice on success.
func readStrings(rows *sql.Rows) ([]string, error) {
	values := make([]string, 0)
	for rows.Next() {
		var value string
		if err := rows.Scan(&value); err != nil {
			return nil, err
		}
		values = append(values, value)
	}
	// Surface any error encountered during iteration (e.g. a broken connection).
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return values, nil
}

View file

@ -74,6 +74,8 @@ const (
postgresSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
postgresUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
postgresUpdateMessageTimeQuery = `UPDATE message SET time = $1 WHERE mid = $2`
postgresSelectAttachmentIDsQuery = `SELECT mid FROM message WHERE attachment_expires > $1 AND attachment_deleted = FALSE`
)
var postgresQueries = queries{
@ -100,6 +102,7 @@ var postgresQueries = queries{
selectStats: postgresSelectStatsQuery,
updateStats: postgresUpdateStatsQuery,
updateMessageTime: postgresUpdateMessageTimeQuery,
selectAttachmentIDs: postgresSelectAttachmentIDsQuery,
}
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.

View file

@ -77,6 +77,8 @@ const (
sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
sqliteSelectAttachmentIDsQuery = `SELECT mid FROM messages WHERE attachment_expires > ? AND attachment_deleted = 0`
)
var sqliteQueries = queries{
@ -103,6 +105,7 @@ var sqliteQueries = queries{
selectStats: sqliteSelectStatsQuery,
updateStats: sqliteUpdateStatsQuery,
updateMessageTime: sqliteUpdateMessageTimeQuery,
selectAttachmentIDs: sqliteSelectAttachmentIDsQuery,
}
// NewSQLiteStore creates a SQLite file-backed cache

View file

@ -196,7 +196,15 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK
}
objects := make([]Object, len(result.Contents))
for i, obj := range result.Contents {
objects[i] = Object(obj)
var lastModified time.Time
if obj.LastModified != "" {
lastModified, _ = time.Parse(time.RFC3339, obj.LastModified)
}
objects[i] = Object{
Key: obj.Key,
Size: obj.Size,
LastModified: lastModified,
}
}
return &ListResult{
Objects: objects,

View file

@ -13,6 +13,7 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
@ -243,7 +244,7 @@ func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucket
var contents []listObject
for _, objKey := range allKeys {
body := m.objects[bucketPath+"/"+objKey]
contents = append(contents, listObject{Key: objKey, Size: int64(len(body))})
contents = append(contents, listObject{Key: objKey, Size: int64(len(body)), LastModified: time.Now().Format(time.RFC3339)})
}
m.mu.RUnlock()

View file

@ -1,6 +1,9 @@
package s3
import "fmt"
import (
"fmt"
"time"
)
// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
type Config struct {
@ -17,6 +20,7 @@ type Config struct {
type Object struct {
Key string
Size int64
LastModified time.Time
}
// ListResult holds the response from a ListObjectsV2 call.
@ -51,6 +55,7 @@ type listObjectsV2Response struct {
type listObject struct {
Key string `xml:"Key"`
Size int64 `xml:"Size"`
LastModified string `xml:"LastModified"`
}
// deleteResult is the XML response from S3 DeleteObjects

View file

@ -65,7 +65,7 @@ type Server struct {
userManager *user.Manager // Might be nil!
messageCache *message.Cache // Database that stores the messages
webPush *webpush.Store // Database that stores web push subscriptions
fileCache attachment.Store // Attachment store (file system or S3)
fileCache *attachment.Store // Attachment store (file system or S3)
stripe stripeAPI // Stripe API, can be replaced with a mock
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
@ -229,7 +229,7 @@ func New(conf *Config) (*Server, error) {
if err != nil {
return nil, err
}
fileCache, err := createAttachmentStore(conf)
fileCache, err := createAttachmentStore(conf, messageCache)
if err != nil {
return nil, err
}
@ -300,11 +300,14 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
return message.NewMemStore()
}
func createAttachmentStore(conf *Config) (attachment.Store, error) {
// createAttachmentStore builds the attachment store for this server: S3-backed
// when an S3 URL is configured, file-backed when a cache directory is configured,
// and nil (attachments disabled) when neither is set.
func createAttachmentStore(conf *Config, messageCache *message.Cache) (*attachment.Store, error) {
	// The store receives a callback to look up message IDs with active
	// attachments from the cache (presumably to decide which stored files
	// are still referenced — confirm against attachment.Store).
	idProvider := func() ([]string, error) {
		return messageCache.AttachmentIDs()
	}
	switch {
	case conf.AttachmentS3URL != "":
		return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit, idProvider)
	case conf.AttachmentCacheDir != "":
		return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, idProvider)
	default:
		return nil, nil
	}
}
@ -429,6 +432,9 @@ func (s *Server) Stop() {
if s.smtpServer != nil {
s.smtpServer.Close()
}
if s.fileCache != nil {
s.fileCache.Close()
}
s.closeDatabases()
close(s.closeChan)
}

View file

@ -152,6 +152,61 @@ func (l *RateLimiter) Reset() {
l.value = 0
}
// CountingReader wraps an io.Reader and counts the number of bytes read through it.
type CountingReader struct {
	r     io.Reader
	total int64
}

// NewCountingReader creates a new CountingReader
func NewCountingReader(r io.Reader) *CountingReader {
	return &CountingReader{r: r}
}

// Read passes through to the underlying reader and counts the bytes read
func (r *CountingReader) Read(p []byte) (int, error) {
	// Count whatever was actually read, even on a short read or an error;
	// io.Reader allows n > 0 alongside a non-nil err.
	n, err := r.r.Read(p)
	r.total += int64(n)
	return n, err
}

// Total returns the total number of bytes read so far
func (r *CountingReader) Total() int64 {
	return r.total
}
// LimitReader implements an io.Reader that will pass through all Read calls to the underlying
// reader r until any of the limiter's limit is reached, at which point a Read will return ErrLimitReached.
// Each limiter's value is increased after every read based on the number of bytes actually read.
type LimitReader struct {
	r        io.Reader
	limiters []Limiter
}

// NewLimitReader creates a new LimitReader
func NewLimitReader(r io.Reader, limiters ...Limiter) *LimitReader {
	return &LimitReader{r: r, limiters: limiters}
}

// Read passes through all reads to the underlying reader until any of the given limiter's limit is reached
func (r *LimitReader) Read(p []byte) (int, error) {
	n, err := r.r.Read(p)
	if n <= 0 {
		return n, err
	}
	for i, limiter := range r.limiters {
		if limiter.AllowN(int64(n)) {
			continue
		}
		// This limiter rejected the read: undo the increments already applied
		// to the limiters before it (in reverse order) and report the limit.
		for j := i - 1; j >= 0; j-- {
			r.limiters[j].AllowN(-int64(n))
		}
		return 0, ErrLimitReached
	}
	return n, err
}
// LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
// writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
// Each limiter's value is increased with every write.