Refactor
This commit is contained in:
parent
cffa57950a
commit
ef314960d0
16 changed files with 682 additions and 436 deletions
22
attachment/backend.go
Normal file
22
attachment/backend.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
package attachment
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// backendObject represents an object stored in a backend.
|
||||
type object struct {
|
||||
ID string
|
||||
Size int64
|
||||
LastModified time.Time
|
||||
}
|
||||
|
||||
// backend is a minimal I/O interface for storing and retrieving attachment files.
|
||||
// It has no knowledge of size tracking, limiting, or ID validation.
|
||||
type backend interface {
|
||||
Put(id string, in io.Reader) error
|
||||
Get(id string) (io.ReadCloser, int64, error)
|
||||
Delete(ids ...string) error
|
||||
List() ([]object, error)
|
||||
}
|
||||
85
attachment/backend_file.go
Normal file
85
attachment/backend_file.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
package attachment
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
)
|
||||
|
||||
const tagFileBackend = "file_backend"
|
||||
|
||||
type fileBackend struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
var _ backend = (*fileBackend)(nil)
|
||||
|
||||
func newFileBackend(dir string) (*fileBackend, error) {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fileBackend{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (b *fileBackend) Put(id string, in io.Reader) error {
|
||||
file := filepath.Join(b.dir, id)
|
||||
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if _, err := io.Copy(f, in); err != nil {
|
||||
os.Remove(file)
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
os.Remove(file)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *fileBackend) Get(id string) (io.ReadCloser, int64, error) {
|
||||
file := filepath.Join(b.dir, id)
|
||||
stat, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return f, stat.Size(), nil
|
||||
}
|
||||
|
||||
func (b *fileBackend) Delete(ids ...string) error {
|
||||
for _, id := range ids {
|
||||
file := filepath.Join(b.dir, id)
|
||||
if err := os.Remove(file); err != nil {
|
||||
log.Tag(tagFileBackend).Field("message_id", id).Err(err).Debug("Error deleting attachment")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *fileBackend) List() ([]object, error) {
|
||||
entries, err := os.ReadDir(b.dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects := make([]object, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
info, err := e.Info()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, object{
|
||||
ID: e.Name(),
|
||||
Size: info.Size(),
|
||||
LastModified: info.ModTime(),
|
||||
})
|
||||
}
|
||||
return objects, nil
|
||||
}
|
||||
70
attachment/backend_s3.go
Normal file
70
attachment/backend_s3.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package attachment
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
)
|
||||
|
||||
const tagS3Backend = "s3_backend"
|
||||
|
||||
type s3Backend struct {
|
||||
client *s3.Client
|
||||
}
|
||||
|
||||
var _ backend = (*s3Backend)(nil)
|
||||
|
||||
func newS3Backend(client *s3.Client) *s3Backend {
|
||||
return &s3Backend{client: client}
|
||||
}
|
||||
|
||||
func (b *s3Backend) Put(id string, in io.Reader) error {
|
||||
return b.client.PutObject(context.Background(), id, in)
|
||||
}
|
||||
|
||||
func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) {
|
||||
return b.client.GetObject(context.Background(), id)
|
||||
}
|
||||
|
||||
func (b *s3Backend) Delete(ids ...string) error {
|
||||
// S3 DeleteObjects supports up to 1000 keys per call
|
||||
for i := 0; i < len(ids); i += 1000 {
|
||||
end := i + 1000
|
||||
if end > len(ids) {
|
||||
end = len(ids)
|
||||
}
|
||||
batch := ids[i:end]
|
||||
for _, id := range batch {
|
||||
log.Tag(tagS3Backend).Field("message_id", id).Debug("Deleting attachment from S3")
|
||||
}
|
||||
if err := b.client.DeleteObjects(context.Background(), batch); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *s3Backend) List() ([]object, error) {
|
||||
objects, err := b.client.ListAllObjects(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
prefix := b.client.Prefix
|
||||
result := make([]object, 0, len(objects))
|
||||
for _, obj := range objects {
|
||||
id := obj.Key
|
||||
if prefix != "" {
|
||||
id = strings.TrimPrefix(id, prefix+"/")
|
||||
}
|
||||
result = append(result, object{
|
||||
ID: id,
|
||||
Size: obj.Size,
|
||||
LastModified: obj.LastModified,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
@ -5,21 +5,193 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/model"
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
tagStore = "attachment_cache"
|
||||
syncInterval = 15 * time.Minute // How often to run the background sync loop
|
||||
orphanGracePeriod = time.Hour // Don't delete orphaned objects younger than this to avoid races with in-flight uploads
|
||||
)
|
||||
|
||||
var (
|
||||
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength))
|
||||
errInvalidFileID = errors.New("invalid file ID")
|
||||
)
|
||||
|
||||
// Store is an interface for storing and retrieving attachment files
|
||||
type Store interface {
|
||||
Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error)
|
||||
Read(id string) (io.ReadCloser, int64, error)
|
||||
Remove(ids ...string) error
|
||||
Size() int64
|
||||
Remaining() int64
|
||||
// Store manages attachment storage with shared logic for size tracking, limiting,
|
||||
// ID validation, and background sync to reconcile storage with the database.
|
||||
type Store struct {
|
||||
backend backend
|
||||
totalSizeCurrent int64
|
||||
totalSizeLimit int64
|
||||
localIDs func() ([]string, error) // returns IDs that should exist
|
||||
closeChan chan struct{}
|
||||
mu sync.Mutex // Protects totalSizeCurrent
|
||||
}
|
||||
|
||||
// NewFileStore creates a new file-system backed attachment cache
|
||||
func NewFileStore(dir string, totalSizeLimit int64, localIDsFn func() ([]string, error)) (*Store, error) {
|
||||
backend, err := newFileBackend(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newStore(backend, totalSizeLimit, localIDsFn)
|
||||
}
|
||||
|
||||
// NewS3Store creates a new S3-backed attachment cache. The s3URL must be in the format:
|
||||
//
|
||||
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
func NewS3Store(s3URL string, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) {
|
||||
config, err := s3.ParseURL(s3URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newStore(newS3Backend(s3.New(config)), totalSizeLimit, localIDs)
|
||||
}
|
||||
|
||||
func newStore(backend backend, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) {
|
||||
c := &Store{
|
||||
backend: backend,
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
localIDs: localIDs,
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
if localIDs != nil {
|
||||
go c.syncLoop()
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Write stores an attachment file. The id is validated, and the write is subject to
|
||||
// the total size limit and any additional limiters.
|
||||
func (c *Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return 0, errInvalidFileID
|
||||
}
|
||||
log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment")
|
||||
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
|
||||
cr := util.NewCountingReader(in)
|
||||
lr := util.NewLimitReader(cr, limiters...)
|
||||
if err := c.backend.Put(id, lr); err != nil {
|
||||
c.backend.Delete(id) //nolint:errcheck
|
||||
return 0, err
|
||||
}
|
||||
size := cr.Total()
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent += size
|
||||
c.mu.Unlock()
|
||||
return size, nil
|
||||
}
|
||||
|
||||
// Read retrieves an attachment file by ID
|
||||
func (c *Store) Read(id string) (io.ReadCloser, int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return nil, 0, errInvalidFileID
|
||||
}
|
||||
return c.backend.Get(id)
|
||||
}
|
||||
|
||||
// Remove deletes attachment files by ID. It does NOT recompute the total size;
|
||||
// the next sync() call will correct it.
|
||||
func (c *Store) Remove(ids ...string) error {
|
||||
for _, id := range ids {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return errInvalidFileID
|
||||
}
|
||||
}
|
||||
return c.backend.Delete(ids...)
|
||||
}
|
||||
|
||||
// sync reconciles the backend storage with the database. It lists all objects,
|
||||
// deletes orphans (not in the valid ID set and older than 1 hour), and recomputes
|
||||
// the total size from the remaining objects.
|
||||
func (c *Store) sync() error {
|
||||
log.Tag(tagStore).Debug("Sync: starting sync loop")
|
||||
localIDs, err := c.localIDs()
|
||||
if err != nil {
|
||||
return fmt.Errorf("attachment sync: failed to get valid IDs: %w", err)
|
||||
}
|
||||
localIDMap := make(map[string]struct{}, len(localIDs))
|
||||
for _, id := range localIDs {
|
||||
localIDMap[id] = struct{}{}
|
||||
}
|
||||
remoteObjects, err := c.backend.List()
|
||||
if err != nil {
|
||||
return fmt.Errorf("attachment sync: failed to list objects: %w", err)
|
||||
}
|
||||
// Calculate total cache size and collect orphaned attachments, excluding objects younger
|
||||
// than the grace period to account for races, and skipping objects with invalid IDs.
|
||||
cutoff := time.Now().Add(-orphanGracePeriod)
|
||||
var orphanIDs []string
|
||||
var totalSize int64
|
||||
for _, obj := range remoteObjects {
|
||||
if !fileIDRegex.MatchString(obj.ID) {
|
||||
continue
|
||||
}
|
||||
if _, ok := localIDMap[obj.ID]; !ok && obj.LastModified.Before(cutoff) {
|
||||
orphanIDs = append(orphanIDs, obj.ID)
|
||||
} else {
|
||||
totalSize += obj.Size
|
||||
}
|
||||
}
|
||||
log.Tag(tagStore).Debug("Sync: cache size updated to %s", util.FormatSizeHuman(totalSize))
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent = totalSize
|
||||
c.mu.Unlock()
|
||||
// Delete orphaned attachments
|
||||
if len(orphanIDs) > 0 {
|
||||
log.Tag(tagStore).Debug("Sync: deleting %d orphaned attachment(s)", len(orphanIDs))
|
||||
if err := c.backend.Delete(orphanIDs...); err != nil {
|
||||
return fmt.Errorf("attachment sync: failed to delete orphaned objects: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size returns the current total size of all attachments
|
||||
func (c *Store) Size() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.totalSizeCurrent
|
||||
}
|
||||
|
||||
// Remaining returns the remaining capacity for attachments
|
||||
func (c *Store) Remaining() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.totalSizeLimit - c.totalSizeCurrent
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
// Close stops the background sync goroutine
|
||||
func (c *Store) Close() {
|
||||
close(c.closeChan)
|
||||
}
|
||||
|
||||
func (c *Store) syncLoop() {
|
||||
if err := c.sync(); err != nil {
|
||||
log.Tag(tagStore).Err(err).Warn("Attachment sync failed")
|
||||
}
|
||||
ticker := time.NewTicker(syncInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := c.sync(); err != nil {
|
||||
log.Tag(tagStore).Err(err).Warn("Attachment sync failed")
|
||||
}
|
||||
case <-c.closeChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,139 +0,0 @@
|
|||
package attachment
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const tagFileStore = "file_store"
|
||||
|
||||
var errFileExists = errors.New("file exists")
|
||||
|
||||
type fileStore struct {
|
||||
dir string
|
||||
totalSizeCurrent int64
|
||||
totalSizeLimit int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewFileStore creates a new file-system backed attachment store
|
||||
func NewFileStore(dir string, totalSizeLimit int64) (Store, error) {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
size, err := dirSize(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fileStore{
|
||||
dir: dir,
|
||||
totalSizeCurrent: size,
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return 0, errInvalidFileID
|
||||
}
|
||||
log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment")
|
||||
file := filepath.Join(c.dir, id)
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
return 0, errFileExists
|
||||
}
|
||||
f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
|
||||
limitWriter := util.NewLimitWriter(f, limiters...)
|
||||
size, err := io.Copy(limitWriter, in)
|
||||
if err != nil {
|
||||
os.Remove(file)
|
||||
return 0, err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
os.Remove(file)
|
||||
return 0, err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent += size
|
||||
c.mu.Unlock()
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return nil, 0, errInvalidFileID
|
||||
}
|
||||
file := filepath.Join(c.dir, id)
|
||||
stat, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return f, stat.Size(), nil
|
||||
}
|
||||
|
||||
func (c *fileStore) Remove(ids ...string) error {
|
||||
for _, id := range ids {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return errInvalidFileID
|
||||
}
|
||||
log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment")
|
||||
file := filepath.Join(c.dir, id)
|
||||
if err := os.Remove(file); err != nil {
|
||||
log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment")
|
||||
}
|
||||
}
|
||||
size, err := dirSize(c.dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent = size
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fileStore) Size() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.totalSizeCurrent
|
||||
}
|
||||
|
||||
func (c *fileStore) Remaining() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.totalSizeLimit - c.totalSizeCurrent
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func dirSize(dir string) (int64, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var size int64
|
||||
for _, e := range entries {
|
||||
info, err := e.Info()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
size += info.Size()
|
||||
}
|
||||
return size, nil
|
||||
}
|
||||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
|
|
@ -17,22 +18,22 @@ var (
|
|||
)
|
||||
|
||||
func TestFileStore_Write_Success(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
|
||||
dir, c := newTestFileStore(t)
|
||||
size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl"))
|
||||
require.Equal(t, int64(11), s.Size())
|
||||
require.Equal(t, int64(10229), s.Remaining())
|
||||
require.Equal(t, int64(11), c.Size())
|
||||
require.Equal(t, int64(10229), c.Remaining())
|
||||
}
|
||||
|
||||
func TestFileStore_Write_Read_Success(t *testing.T) {
|
||||
_, s := newTestFileStore(t)
|
||||
size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"))
|
||||
_, c := newTestFileStore(t)
|
||||
size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
|
||||
reader, readSize, err := s.Read("abcdefghijkl")
|
||||
reader, readSize, err := c.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), readSize)
|
||||
defer reader.Close()
|
||||
|
|
@ -42,57 +43,111 @@ func TestFileStore_Write_Read_Success(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFileStore_Write_Remove_Success(t *testing.T) {
|
||||
dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
|
||||
dir, c := newTestFileStore(t) // max = 10k (10240), each = 1k (1024)
|
||||
for i := 0; i < 10; i++ { // 10x999 = 9990
|
||||
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
|
||||
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(999), size)
|
||||
}
|
||||
require.Equal(t, int64(9990), s.Size())
|
||||
require.Equal(t, int64(250), s.Remaining())
|
||||
require.Equal(t, int64(9990), c.Size())
|
||||
require.Equal(t, int64(250), c.Remaining())
|
||||
require.FileExists(t, dir+"/abcdefghijk1")
|
||||
require.FileExists(t, dir+"/abcdefghijk5")
|
||||
|
||||
require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5"))
|
||||
require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5"))
|
||||
require.NoFileExists(t, dir+"/abcdefghijk1")
|
||||
require.NoFileExists(t, dir+"/abcdefghijk5")
|
||||
require.Equal(t, int64(7992), s.Size())
|
||||
require.Equal(t, int64(2248), s.Remaining())
|
||||
// Size is not recomputed by Remove; it stays stale until next sync
|
||||
require.Equal(t, int64(9990), c.Size())
|
||||
}
|
||||
|
||||
func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
dir, c := newTestFileStore(t)
|
||||
for i := 0; i < 10; i++ {
|
||||
size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
|
||||
size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(1024), size)
|
||||
}
|
||||
_, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
|
||||
_, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkX")
|
||||
}
|
||||
|
||||
func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) {
|
||||
dir, s := newTestFileStore(t)
|
||||
_, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
|
||||
dir, c := newTestFileStore(t)
|
||||
_, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
require.NoFileExists(t, dir+"/abcdefghijkl")
|
||||
}
|
||||
|
||||
func TestFileStore_Read_NotFound(t *testing.T) {
|
||||
_, s := newTestFileStore(t)
|
||||
_, _, err := s.Read("abcdefghijkl")
|
||||
_, c := newTestFileStore(t)
|
||||
_, _, err := c.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func newTestFileStore(t *testing.T) (dir string, store Store) {
|
||||
dir = t.TempDir()
|
||||
store, err := NewFileStore(dir, 10*1024)
|
||||
func TestFileStore_Sync(t *testing.T) {
|
||||
dir, c := newTestFileStore(t)
|
||||
|
||||
// Write some files
|
||||
_, err := c.Write("abcdefghijk0", strings.NewReader("file0"))
|
||||
require.Nil(t, err)
|
||||
return dir, store
|
||||
_, err = c.Write("abcdefghijk1", strings.NewReader("file1"))
|
||||
require.Nil(t, err)
|
||||
_, err = c.Write("abcdefghijk2", strings.NewReader("file2"))
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, int64(15), c.Size())
|
||||
|
||||
// Set the ID provider to only know about file 0 and 2
|
||||
c.localIDs = func() ([]string, error) {
|
||||
return []string{"abcdefghijk0", "abcdefghijk2"}, nil
|
||||
}
|
||||
|
||||
// Make file 1's mod time old enough to be cleaned up (> 1 hour)
|
||||
oldTime := time.Unix(1, 0)
|
||||
os.Chtimes(dir+"/abcdefghijk1", oldTime, oldTime)
|
||||
|
||||
// Run sync
|
||||
require.Nil(t, c.sync())
|
||||
|
||||
// File 1 should be deleted (orphan, old enough)
|
||||
require.NoFileExists(t, dir+"/abcdefghijk1")
|
||||
require.FileExists(t, dir+"/abcdefghijk0")
|
||||
require.FileExists(t, dir+"/abcdefghijk2")
|
||||
|
||||
// Size should be updated
|
||||
require.Equal(t, int64(10), c.Size())
|
||||
}
|
||||
|
||||
func TestFileStore_Sync_SkipsRecentFiles(t *testing.T) {
|
||||
dir, c := newTestFileStore(t)
|
||||
|
||||
// Write a file
|
||||
_, err := c.Write("abcdefghijk0", strings.NewReader("file0"))
|
||||
require.Nil(t, err)
|
||||
|
||||
// Set the ID provider to return empty (no valid IDs)
|
||||
c.localIDs = func() ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// File was just created, so it should NOT be deleted (< 1 hour old)
|
||||
require.Nil(t, c.sync())
|
||||
require.FileExists(t, dir+"/abcdefghijk0")
|
||||
}
|
||||
|
||||
func newTestFileStore(t *testing.T) (dir string, cache *Store) {
|
||||
t.Helper()
|
||||
dir = t.TempDir()
|
||||
cache, err := NewFileStore(dir, 10*1024, nil)
|
||||
require.Nil(t, err)
|
||||
t.Cleanup(func() { cache.Close() })
|
||||
return dir, cache
|
||||
}
|
||||
|
||||
func readFile(t *testing.T, f string) string {
|
||||
t.Helper()
|
||||
b, err := os.ReadFile(f)
|
||||
require.Nil(t, err)
|
||||
return string(b)
|
||||
|
|
|
|||
|
|
@ -1,150 +0,0 @@
|
|||
package attachment
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"heckel.io/ntfy/v2/log"
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
"heckel.io/ntfy/v2/util"
|
||||
)
|
||||
|
||||
const (
|
||||
tagS3Store = "s3_store"
|
||||
)
|
||||
|
||||
type s3Store struct {
|
||||
client *s3.Client
|
||||
totalSizeCurrent int64
|
||||
totalSizeLimit int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format:
|
||||
//
|
||||
// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
|
||||
func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) {
|
||||
cfg, err := s3.ParseURL(s3URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
store := &s3Store{
|
||||
client: s3.New(cfg),
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}
|
||||
if totalSizeLimit > 0 {
|
||||
size, err := store.computeSize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err)
|
||||
}
|
||||
store.totalSizeCurrent = size
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return 0, errInvalidFileID
|
||||
}
|
||||
log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3")
|
||||
|
||||
// Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked
|
||||
// uploads, so no temp file or Content-Length is needed.
|
||||
limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
|
||||
pr, pw := io.Pipe()
|
||||
lw := util.NewLimitWriter(pw, limiters...)
|
||||
var size int64
|
||||
var copyErr error
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
size, copyErr = io.Copy(lw, in)
|
||||
if copyErr != nil {
|
||||
pw.CloseWithError(copyErr)
|
||||
} else {
|
||||
pw.Close()
|
||||
}
|
||||
}()
|
||||
putErr := c.client.PutObject(context.Background(), id, pr)
|
||||
pr.Close()
|
||||
<-done
|
||||
if copyErr != nil {
|
||||
return 0, copyErr
|
||||
}
|
||||
if putErr != nil {
|
||||
return 0, putErr
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent += size
|
||||
c.mu.Unlock()
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return nil, 0, errInvalidFileID
|
||||
}
|
||||
return c.client.GetObject(context.Background(), id)
|
||||
}
|
||||
|
||||
func (c *s3Store) Remove(ids ...string) error {
|
||||
for _, id := range ids {
|
||||
if !fileIDRegex.MatchString(id) {
|
||||
return errInvalidFileID
|
||||
}
|
||||
}
|
||||
// S3 DeleteObjects supports up to 1000 keys per call
|
||||
for i := 0; i < len(ids); i += 1000 {
|
||||
end := i + 1000
|
||||
if end > len(ids) {
|
||||
end = len(ids)
|
||||
}
|
||||
batch := ids[i:end]
|
||||
for _, id := range batch {
|
||||
log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3")
|
||||
}
|
||||
if err := c.client.DeleteObjects(context.Background(), batch); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern)
|
||||
size, err := c.computeSize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("s3 store: failed to compute size after remove: %w", err)
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.totalSizeCurrent = size
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *s3Store) Size() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.totalSizeCurrent
|
||||
}
|
||||
|
||||
func (c *s3Store) Remaining() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
remaining := c.totalSizeLimit - c.totalSizeCurrent
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix.
|
||||
func (c *s3Store) computeSize() (int64, error) {
|
||||
objects, err := c.client.ListAllObjects(context.Background())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var totalSize int64
|
||||
for _, obj := range objects {
|
||||
totalSize += obj.Size
|
||||
}
|
||||
return totalSize, nil
|
||||
}
|
||||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"heckel.io/ntfy/v2/s3"
|
||||
|
|
@ -22,16 +23,16 @@ func TestS3Store_WriteReadRemove(t *testing.T) {
|
|||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
// Write
|
||||
size, err := store.Write("abcdefghijkl", strings.NewReader("hello world"))
|
||||
size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), size)
|
||||
require.Equal(t, int64(11), store.Size())
|
||||
require.Equal(t, int64(11), cache.Size())
|
||||
|
||||
// Read back
|
||||
reader, readSize, err := store.Read("abcdefghijkl")
|
||||
reader, readSize, err := cache.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(11), readSize)
|
||||
data, err := io.ReadAll(reader)
|
||||
|
|
@ -40,11 +41,11 @@ func TestS3Store_WriteReadRemove(t *testing.T) {
|
|||
require.Equal(t, "hello world", string(data))
|
||||
|
||||
// Remove
|
||||
require.Nil(t, store.Remove("abcdefghijkl"))
|
||||
require.Equal(t, int64(0), store.Size())
|
||||
require.Nil(t, cache.Remove("abcdefghijkl"))
|
||||
// Size is not recomputed by Remove; stays stale until next sync
|
||||
|
||||
// Read after remove should fail
|
||||
_, _, err = store.Read("abcdefghijkl")
|
||||
_, _, err = cache.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -52,13 +53,13 @@ func TestS3Store_WriteNoPrefix(t *testing.T) {
|
|||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "", 10*1024)
|
||||
cache := newTestS3Store(t, server, "my-bucket", "", 10*1024)
|
||||
|
||||
size, err := store.Write("abcdefghijkl", strings.NewReader("test"))
|
||||
size, err := cache.Write("abcdefghijkl", strings.NewReader("test"))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(4), size)
|
||||
|
||||
reader, _, err := store.Read("abcdefghijkl")
|
||||
reader, _, err := cache.Read("abcdefghijkl")
|
||||
require.Nil(t, err)
|
||||
data, err := io.ReadAll(reader)
|
||||
reader.Close()
|
||||
|
|
@ -70,52 +71,53 @@ func TestS3Store_WriteTotalSizeLimit(t *testing.T) {
|
|||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 100)
|
||||
cache := newTestS3Store(t, server, "my-bucket", "pfx", 100)
|
||||
|
||||
// First write fits
|
||||
_, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
|
||||
_, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)))
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, int64(80), store.Size())
|
||||
require.Equal(t, int64(20), store.Remaining())
|
||||
require.Equal(t, int64(80), cache.Size())
|
||||
require.Equal(t, int64(20), cache.Remaining())
|
||||
|
||||
// Second write exceeds total limit
|
||||
_, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
_, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)))
|
||||
require.ErrorIs(t, err, util.ErrLimitReached)
|
||||
}
|
||||
|
||||
func TestS3Store_WriteFileSizeLimit(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
|
||||
require.Equal(t, util.ErrLimitReached, err)
|
||||
_, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100))
|
||||
require.ErrorIs(t, err, util.ErrLimitReached)
|
||||
}
|
||||
|
||||
func TestS3Store_WriteRemoveMultiple(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
|
||||
_, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)))
|
||||
require.Nil(t, err)
|
||||
}
|
||||
require.Equal(t, int64(500), store.Size())
|
||||
require.Equal(t, int64(500), cache.Size())
|
||||
|
||||
require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3"))
|
||||
require.Equal(t, int64(300), store.Size())
|
||||
require.Nil(t, cache.Remove("abcdefghijk1", "abcdefghijk3"))
|
||||
// Size not recomputed by Remove
|
||||
require.Equal(t, int64(500), cache.Size())
|
||||
}
|
||||
|
||||
func TestS3Store_ReadNotFound(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, _, err := store.Read("abcdefghijkl")
|
||||
_, _, err := cache.Read("abcdefghijkl")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
|
|
@ -123,42 +125,93 @@ func TestS3Store_InvalidID(t *testing.T) {
|
|||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, err := store.Write("bad", strings.NewReader("x"))
|
||||
_, err := cache.Write("bad", strings.NewReader("x"))
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
|
||||
_, _, err = store.Read("bad")
|
||||
_, _, err = cache.Read("bad")
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
|
||||
err = store.Remove("bad")
|
||||
err = cache.Remove("bad")
|
||||
require.Equal(t, errInvalidFileID, err)
|
||||
}
|
||||
|
||||
func TestS3Store_Sync(t *testing.T) {
|
||||
server := newMockS3Server()
|
||||
defer server.Close()
|
||||
|
||||
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
// Write some files
|
||||
_, err := cache.Write("abcdefghijk0", strings.NewReader("file0"))
|
||||
require.Nil(t, err)
|
||||
_, err = cache.Write("abcdefghijk1", strings.NewReader("file1"))
|
||||
require.Nil(t, err)
|
||||
_, err = cache.Write("abcdefghijk2", strings.NewReader("file2"))
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Equal(t, int64(15), cache.Size())
|
||||
|
||||
// Set the ID provider to only know about file 0 and 2
|
||||
// All mock objects have LastModified set to 2 hours ago, so orphans are eligible for deletion
|
||||
cache.localIDs = func() ([]string, error) {
|
||||
return []string{"abcdefghijk0", "abcdefghijk2"}, nil
|
||||
}
|
||||
|
||||
// Run sync
|
||||
require.Nil(t, cache.sync())
|
||||
|
||||
// File 1 should be deleted (orphan)
|
||||
_, _, err = cache.Read("abcdefghijk1")
|
||||
require.Error(t, err)
|
||||
|
||||
// Size should be updated
|
||||
require.Equal(t, int64(10), cache.Size())
|
||||
}
|
||||
|
||||
func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) {
|
||||
mockServer := newMockS3ServerWithModTime(time.Now())
|
||||
defer mockServer.Close()
|
||||
|
||||
cache := newTestS3Store(t, mockServer, "my-bucket", "pfx", 10*1024)
|
||||
|
||||
_, err := cache.Write("abcdefghijk0", strings.NewReader("file0"))
|
||||
require.Nil(t, err)
|
||||
|
||||
// Set the ID provider to return empty (no valid IDs)
|
||||
cache.localIDs = func() ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// File was "just created" (mock returns recent time), so it should NOT be deleted
|
||||
require.Nil(t, cache.sync())
|
||||
|
||||
// File should still exist
|
||||
reader, _, err := cache.Read("abcdefghijk0")
|
||||
require.Nil(t, err)
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store {
|
||||
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store {
|
||||
t.Helper()
|
||||
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
|
||||
host := strings.TrimPrefix(server.URL, "https://")
|
||||
s := &s3Store{
|
||||
client: &s3.Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: host,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
PathStyle: true,
|
||||
HTTPClient: server.Client(),
|
||||
},
|
||||
totalSizeLimit: totalSizeLimit,
|
||||
}
|
||||
// Compute initial size (should be 0 for fresh mock)
|
||||
size, err := s.computeSize()
|
||||
backend := newS3Backend(&s3.Client{
|
||||
AccessKey: "AKID",
|
||||
SecretKey: "SECRET",
|
||||
Region: "us-east-1",
|
||||
Endpoint: host,
|
||||
Bucket: bucket,
|
||||
Prefix: prefix,
|
||||
PathStyle: true,
|
||||
HTTPClient: server.Client(),
|
||||
})
|
||||
cache, err := newStore(backend, totalSizeLimit, nil)
|
||||
require.Nil(t, err)
|
||||
s.totalSizeCurrent = size
|
||||
return s
|
||||
t.Cleanup(func() { cache.Close() })
|
||||
return cache
|
||||
}
|
||||
|
||||
// --- Mock S3 server ---
|
||||
|
|
@ -167,16 +220,22 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string
|
|||
// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory.
|
||||
|
||||
type mockS3Server struct {
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
|
||||
nextID int // counter for generating upload IDs
|
||||
mu sync.RWMutex
|
||||
objects map[string][]byte // full key (bucket/key) -> body
|
||||
uploads map[string]map[int][]byte // uploadID -> partNumber -> data
|
||||
nextID int // counter for generating upload IDs
|
||||
lastModTime time.Time // time to return for LastModified in list responses
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockS3Server() *httptest.Server {
|
||||
return newMockS3ServerWithModTime(time.Now().Add(-2 * time.Hour))
|
||||
}
|
||||
|
||||
func newMockS3ServerWithModTime(modTime time.Time) *httptest.Server {
|
||||
m := &mockS3Server{
|
||||
objects: make(map[string][]byte),
|
||||
uploads: make(map[string]map[int][]byte),
|
||||
objects: make(map[string][]byte),
|
||||
uploads: make(map[string]map[int][]byte),
|
||||
lastModTime: modTime,
|
||||
}
|
||||
return httptest.NewTLSServer(m)
|
||||
}
|
||||
|
|
@ -341,7 +400,11 @@ func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucket
|
|||
continue // different bucket
|
||||
}
|
||||
if prefix == "" || strings.HasPrefix(objKey, prefix) {
|
||||
contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))})
|
||||
contents = append(contents, s3ListObject{
|
||||
Key: objKey,
|
||||
Size: int64(len(body)),
|
||||
LastModified: m.lastModTime.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
|
@ -362,6 +425,7 @@ type s3ListResponse struct {
|
|||
}
|
||||
|
||||
type s3ListObject struct {
|
||||
Key string `xml:"Key"`
|
||||
Size int64 `xml:"Size"`
|
||||
Key string `xml:"Key"`
|
||||
Size int64 `xml:"Size"`
|
||||
LastModified string `xml:"LastModified"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ type queries struct {
|
|||
selectStats string
|
||||
updateStats string
|
||||
updateMessageTime string
|
||||
selectAttachmentIDs string
|
||||
}
|
||||
|
||||
// Cache stores published messages
|
||||
|
|
@ -252,18 +253,7 @@ func (c *Cache) MessagesExpired() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
return readStrings(rows)
|
||||
}
|
||||
|
||||
// Message returns the message with the given ID, or ErrMessageNotFound if not found
|
||||
|
|
@ -319,18 +309,7 @@ func (c *Cache) Topics() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
topics := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topics = append(topics, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics, nil
|
||||
return readStrings(rows)
|
||||
}
|
||||
|
||||
// DeleteMessages deletes the messages with the given IDs
|
||||
|
|
@ -358,15 +337,8 @@ func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string,
|
|||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
ids, err := readStrings(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows.Close() // Close rows before executing delete in same transaction
|
||||
|
|
@ -391,6 +363,16 @@ func (c *Cache) ExpireMessages(topics ...string) error {
|
|||
})
|
||||
}
|
||||
|
||||
// AttachmentIDs returns message IDs with active (non-expired, non-deleted) attachments
|
||||
func (c *Cache) AttachmentIDs() ([]string, error) {
|
||||
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentIDs, time.Now().Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return readStrings(rows)
|
||||
}
|
||||
|
||||
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
|
||||
func (c *Cache) AttachmentsExpired() ([]string, error) {
|
||||
rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix())
|
||||
|
|
@ -398,18 +380,7 @@ func (c *Cache) AttachmentsExpired() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
return readStrings(rows)
|
||||
}
|
||||
|
||||
// MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted
|
||||
|
|
@ -590,3 +561,18 @@ func readMessage(rows *sql.Rows) (*model.Message, error) {
|
|||
Encoding: encoding,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func readStrings(rows *sql.Rows) ([]string, error) {
|
||||
strs := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var s string
|
||||
if err := rows.Scan(&s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
strs = append(strs, s)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return strs, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,6 +74,8 @@ const (
|
|||
postgresSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'`
|
||||
postgresUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'`
|
||||
postgresUpdateMessageTimeQuery = `UPDATE message SET time = $1 WHERE mid = $2`
|
||||
|
||||
postgresSelectAttachmentIDsQuery = `SELECT mid FROM message WHERE attachment_expires > $1 AND attachment_deleted = FALSE`
|
||||
)
|
||||
|
||||
var postgresQueries = queries{
|
||||
|
|
@ -100,6 +102,7 @@ var postgresQueries = queries{
|
|||
selectStats: postgresSelectStatsQuery,
|
||||
updateStats: postgresUpdateStatsQuery,
|
||||
updateMessageTime: postgresUpdateMessageTimeQuery,
|
||||
selectAttachmentIDs: postgresSelectAttachmentIDsQuery,
|
||||
}
|
||||
|
||||
// NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool.
|
||||
|
|
|
|||
|
|
@ -77,6 +77,8 @@ const (
|
|||
sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
|
||||
sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
|
||||
sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?`
|
||||
|
||||
sqliteSelectAttachmentIDsQuery = `SELECT mid FROM messages WHERE attachment_expires > ? AND attachment_deleted = 0`
|
||||
)
|
||||
|
||||
var sqliteQueries = queries{
|
||||
|
|
@ -103,6 +105,7 @@ var sqliteQueries = queries{
|
|||
selectStats: sqliteSelectStatsQuery,
|
||||
updateStats: sqliteUpdateStatsQuery,
|
||||
updateMessageTime: sqliteUpdateMessageTimeQuery,
|
||||
selectAttachmentIDs: sqliteSelectAttachmentIDsQuery,
|
||||
}
|
||||
|
||||
// NewSQLiteStore creates a SQLite file-backed cache
|
||||
|
|
|
|||
10
s3/client.go
10
s3/client.go
|
|
@ -196,7 +196,15 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK
|
|||
}
|
||||
objects := make([]Object, len(result.Contents))
|
||||
for i, obj := range result.Contents {
|
||||
objects[i] = Object(obj)
|
||||
var lastModified time.Time
|
||||
if obj.LastModified != "" {
|
||||
lastModified, _ = time.Parse(time.RFC3339, obj.LastModified)
|
||||
}
|
||||
objects[i] = Object{
|
||||
Key: obj.Key,
|
||||
Size: obj.Size,
|
||||
LastModified: lastModified,
|
||||
}
|
||||
}
|
||||
return &ListResult{
|
||||
Objects: objects,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -243,7 +244,7 @@ func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucket
|
|||
var contents []listObject
|
||||
for _, objKey := range allKeys {
|
||||
body := m.objects[bucketPath+"/"+objKey]
|
||||
contents = append(contents, listObject{Key: objKey, Size: int64(len(body))})
|
||||
contents = append(contents, listObject{Key: objKey, Size: int64(len(body)), LastModified: time.Now().Format(time.RFC3339)})
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
|
|
|
|||
15
s3/types.go
15
s3/types.go
|
|
@ -1,6 +1,9 @@
|
|||
package s3
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
|
||||
type Config struct {
|
||||
|
|
@ -15,8 +18,9 @@ type Config struct {
|
|||
|
||||
// Object represents an S3 object returned by list operations.
|
||||
type Object struct {
|
||||
Key string
|
||||
Size int64
|
||||
Key string
|
||||
Size int64
|
||||
LastModified time.Time
|
||||
}
|
||||
|
||||
// ListResult holds the response from a ListObjectsV2 call.
|
||||
|
|
@ -49,8 +53,9 @@ type listObjectsV2Response struct {
|
|||
}
|
||||
|
||||
type listObject struct {
|
||||
Key string `xml:"Key"`
|
||||
Size int64 `xml:"Size"`
|
||||
Key string `xml:"Key"`
|
||||
Size int64 `xml:"Size"`
|
||||
LastModified string `xml:"LastModified"`
|
||||
}
|
||||
|
||||
// deleteResult is the XML response from S3 DeleteObjects
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ type Server struct {
|
|||
userManager *user.Manager // Might be nil!
|
||||
messageCache *message.Cache // Database that stores the messages
|
||||
webPush *webpush.Store // Database that stores web push subscriptions
|
||||
fileCache attachment.Store // Attachment store (file system or S3)
|
||||
fileCache *attachment.Store // Attachment store (file system or S3)
|
||||
stripe stripeAPI // Stripe API, can be replaced with a mock
|
||||
priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
|
||||
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set
|
||||
|
|
@ -229,7 +229,7 @@ func New(conf *Config) (*Server, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileCache, err := createAttachmentStore(conf)
|
||||
fileCache, err := createAttachmentStore(conf, messageCache)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -300,11 +300,14 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) {
|
|||
return message.NewMemStore()
|
||||
}
|
||||
|
||||
func createAttachmentStore(conf *Config) (attachment.Store, error) {
|
||||
func createAttachmentStore(conf *Config, messageCache *message.Cache) (*attachment.Store, error) {
|
||||
idProvider := func() ([]string, error) {
|
||||
return messageCache.AttachmentIDs()
|
||||
}
|
||||
if conf.AttachmentS3URL != "" {
|
||||
return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit)
|
||||
return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit, idProvider)
|
||||
} else if conf.AttachmentCacheDir != "" {
|
||||
return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
|
||||
return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, idProvider)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -429,6 +432,9 @@ func (s *Server) Stop() {
|
|||
if s.smtpServer != nil {
|
||||
s.smtpServer.Close()
|
||||
}
|
||||
if s.fileCache != nil {
|
||||
s.fileCache.Close()
|
||||
}
|
||||
s.closeDatabases()
|
||||
close(s.closeChan)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -152,6 +152,61 @@ func (l *RateLimiter) Reset() {
|
|||
l.value = 0
|
||||
}
|
||||
|
||||
// CountingReader wraps an io.Reader and counts the number of bytes read through it.
|
||||
type CountingReader struct {
|
||||
r io.Reader
|
||||
total int64
|
||||
}
|
||||
|
||||
// NewCountingReader creates a new CountingReader
|
||||
func NewCountingReader(r io.Reader) *CountingReader {
|
||||
return &CountingReader{r: r}
|
||||
}
|
||||
|
||||
// Read passes through to the underlying reader and counts the bytes read
|
||||
func (r *CountingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.r.Read(p)
|
||||
r.total += int64(n)
|
||||
return
|
||||
}
|
||||
|
||||
// Total returns the total number of bytes read so far
|
||||
func (r *CountingReader) Total() int64 {
|
||||
return r.total
|
||||
}
|
||||
|
||||
// LimitReader implements an io.Reader that will pass through all Read calls to the underlying
|
||||
// reader r until any of the limiter's limit is reached, at which point a Read will return ErrLimitReached.
|
||||
// Each limiter's value is increased after every read based on the number of bytes actually read.
|
||||
type LimitReader struct {
|
||||
r io.Reader
|
||||
limiters []Limiter
|
||||
}
|
||||
|
||||
// NewLimitReader creates a new LimitReader
|
||||
func NewLimitReader(r io.Reader, limiters ...Limiter) *LimitReader {
|
||||
return &LimitReader{
|
||||
r: r,
|
||||
limiters: limiters,
|
||||
}
|
||||
}
|
||||
|
||||
// Read passes through all reads to the underlying reader until any of the given limiter's limit is reached
|
||||
func (r *LimitReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.r.Read(p)
|
||||
if n > 0 {
|
||||
for i := 0; i < len(r.limiters); i++ {
|
||||
if !r.limiters[i].AllowN(int64(n)) {
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
r.limiters[j].AllowN(-int64(n)) // Revert limiters if not allowed
|
||||
}
|
||||
return 0, ErrLimitReached
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
|
||||
// writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
|
||||
// Each limiter's value is increased with every write.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue