From d517ce4a2aa21ef8379a8dc013fc6e58cdebbebe Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 14 Mar 2026 21:10:46 -0400 Subject: [PATCH 01/32] WIP: S3 --- attachment/store.go | 25 ++ .../file_cache.go => attachment/store_file.go | 57 ++-- attachment/store_file_test.go | 99 +++++++ attachment/store_s3.go | 255 ++++++++++++++++++ attachment/store_s3_test.go | 76 ++++++ cmd/serve.go | 7 + docs/config.md | 45 +++- docs/releases.md | 1 + go.mod | 12 + go.sum | 24 ++ server/config.go | 1 + server/file_cache_test.go | 76 ------ server/log.go | 1 - server/server.go | 42 +-- server/server.yml | 1 + server/server_manager.go | 3 + 16 files changed, 597 insertions(+), 128 deletions(-) create mode 100644 attachment/store.go rename server/file_cache.go => attachment/store_file.go (66%) create mode 100644 attachment/store_file_test.go create mode 100644 attachment/store_s3.go create mode 100644 attachment/store_s3_test.go delete mode 100644 server/file_cache_test.go diff --git a/attachment/store.go b/attachment/store.go new file mode 100644 index 00000000..c48a1e90 --- /dev/null +++ b/attachment/store.go @@ -0,0 +1,25 @@ +package attachment + +import ( + "errors" + "fmt" + "io" + "regexp" + + "heckel.io/ntfy/v2/model" + "heckel.io/ntfy/v2/util" +) + +// Store is an interface for storing and retrieving attachment files +type Store interface { + Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) + Read(id string) (io.ReadCloser, int64, error) + Remove(ids ...string) error + Size() int64 + Remaining() int64 +} + +var ( + fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength)) + errInvalidFileID = errors.New("invalid file ID") +) diff --git a/server/file_cache.go b/attachment/store_file.go similarity index 66% rename from server/file_cache.go rename to attachment/store_file.go index a1803724..b26e86f0 100644 --- a/server/file_cache.go +++ b/attachment/store_file.go @@ -1,32 +1,29 @@ -package server +package attachment 
import ( "errors" - "fmt" - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/model" - "heckel.io/ntfy/v2/util" "io" "os" "path/filepath" - "regexp" "sync" + + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/util" ) -var ( - fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength)) - errInvalidFileID = errors.New("invalid file ID") - errFileExists = errors.New("file exists") -) +const tagFileStore = "file_store" -type fileCache struct { +var errFileExists = errors.New("file exists") + +type fileStore struct { dir string totalSizeCurrent int64 totalSizeLimit int64 mu sync.Mutex } -func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) { +// NewFileStore creates a new file-system backed attachment store +func NewFileStore(dir string, totalSizeLimit int64) (Store, error) { if err := os.MkdirAll(dir, 0700); err != nil { return nil, err } @@ -34,18 +31,18 @@ func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) { if err != nil { return nil, err } - return &fileCache{ + return &fileStore{ dir: dir, totalSizeCurrent: size, totalSizeLimit: totalSizeLimit, }, nil } -func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { +func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { if !fileIDRegex.MatchString(id) { return 0, errInvalidFileID } - log.Tag(tagFileCache).Field("message_id", id).Debug("Writing attachment") + log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment") file := filepath.Join(c.dir, id) if _, err := os.Stat(file); err == nil { return 0, errFileExists @@ -68,20 +65,35 @@ func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (in } c.mu.Lock() c.totalSizeCurrent += size - mset(metricAttachmentsTotalSize, c.totalSizeCurrent) c.mu.Unlock() return size, nil } -func (c *fileCache) Remove(ids ...string) error { +func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) { + if 
!fileIDRegex.MatchString(id) { + return nil, 0, errInvalidFileID + } + file := filepath.Join(c.dir, id) + stat, err := os.Stat(file) + if err != nil { + return nil, 0, err + } + f, err := os.Open(file) + if err != nil { + return nil, 0, err + } + return f, stat.Size(), nil +} + +func (c *fileStore) Remove(ids ...string) error { for _, id := range ids { if !fileIDRegex.MatchString(id) { return errInvalidFileID } - log.Tag(tagFileCache).Field("message_id", id).Debug("Deleting attachment") + log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment") file := filepath.Join(c.dir, id) if err := os.Remove(file); err != nil { - log.Tag(tagFileCache).Field("message_id", id).Err(err).Debug("Error deleting attachment") + log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment") } } size, err := dirSize(c.dir) @@ -91,17 +103,16 @@ func (c *fileCache) Remove(ids ...string) error { c.mu.Lock() c.totalSizeCurrent = size c.mu.Unlock() - mset(metricAttachmentsTotalSize, size) return nil } -func (c *fileCache) Size() int64 { +func (c *fileStore) Size() int64 { c.mu.Lock() defer c.mu.Unlock() return c.totalSizeCurrent } -func (c *fileCache) Remaining() int64 { +func (c *fileStore) Remaining() int64 { c.mu.Lock() defer c.mu.Unlock() remaining := c.totalSizeLimit - c.totalSizeCurrent diff --git a/attachment/store_file_test.go b/attachment/store_file_test.go new file mode 100644 index 00000000..5cfe0db4 --- /dev/null +++ b/attachment/store_file_test.go @@ -0,0 +1,99 @@ +package attachment + +import ( + "bytes" + "fmt" + "io" + "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/util" +) + +var ( + oneKilobyteArray = make([]byte, 1024) +) + +func TestFileStore_Write_Success(t *testing.T) { + dir, s := newTestFileStore(t) + size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) + require.Nil(t, err) + require.Equal(t, int64(11), size) + require.Equal(t, 
"normal file", readFile(t, dir+"/abcdefghijkl")) + require.Equal(t, int64(11), s.Size()) + require.Equal(t, int64(10229), s.Remaining()) +} + +func TestFileStore_Write_Read_Success(t *testing.T) { + _, s := newTestFileStore(t) + size, err := s.Write("abcdefghijkl", strings.NewReader("hello world")) + require.Nil(t, err) + require.Equal(t, int64(11), size) + + reader, readSize, err := s.Read("abcdefghijkl") + require.Nil(t, err) + require.Equal(t, int64(11), readSize) + defer reader.Close() + data, err := io.ReadAll(reader) + require.Nil(t, err) + require.Equal(t, "hello world", string(data)) +} + +func TestFileStore_Write_Remove_Success(t *testing.T) { + dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024) + for i := 0; i < 10; i++ { // 10x999 = 9990 + size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) + require.Nil(t, err) + require.Equal(t, int64(999), size) + } + require.Equal(t, int64(9990), s.Size()) + require.Equal(t, int64(250), s.Remaining()) + require.FileExists(t, dir+"/abcdefghijk1") + require.FileExists(t, dir+"/abcdefghijk5") + + require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk5")) + require.NoFileExists(t, dir+"/abcdefghijk1") + require.NoFileExists(t, dir+"/abcdefghijk5") + require.Equal(t, int64(7992), s.Size()) + require.Equal(t, int64(2248), s.Remaining()) +} + +func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) { + dir, s := newTestFileStore(t) + for i := 0; i < 10; i++ { + size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) + require.Nil(t, err) + require.Equal(t, int64(1024), size) + } + _, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) + require.Equal(t, util.ErrLimitReached, err) + require.NoFileExists(t, dir+"/abcdefghijkX") +} + +func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) { + dir, s := newTestFileStore(t) + _, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), 
util.NewFixedLimiter(1000)) + require.Equal(t, util.ErrLimitReached, err) + require.NoFileExists(t, dir+"/abcdefghijkl") +} + +func TestFileStore_Read_NotFound(t *testing.T) { + _, s := newTestFileStore(t) + _, _, err := s.Read("abcdefghijkl") + require.Error(t, err) +} + +func newTestFileStore(t *testing.T) (dir string, store Store) { + dir = t.TempDir() + store, err := NewFileStore(dir, 10*1024) + require.Nil(t, err) + return dir, store +} + +func readFile(t *testing.T, f string) string { + b, err := os.ReadFile(f) + require.Nil(t, err) + return string(b) +} diff --git a/attachment/store_s3.go b/attachment/store_s3.go new file mode 100644 index 00000000..118da4ce --- /dev/null +++ b/attachment/store_s3.go @@ -0,0 +1,255 @@ +package attachment + +import ( + "context" + "fmt" + "io" + "net/url" + "strings" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" + "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/util" +) + +const tagS3Store = "s3_store" + +type s3Store struct { + client *s3.Client + bucket string + prefix string + totalSizeCurrent int64 + totalSizeLimit int64 + mu sync.Mutex +} + +// NewS3Store creates a new S3-backed attachment store. 
The s3URL must be in the format: +// +// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] +func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) { + bucket, prefix, client, err := parseS3URL(s3URL) + if err != nil { + return nil, err + } + store := &s3Store{ + client: client, + bucket: bucket, + prefix: prefix, + totalSizeLimit: totalSizeLimit, + } + if totalSizeLimit > 0 { + size, err := store.computeSize() + if err != nil { + return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err) + } + store.totalSizeCurrent = size + } + return store, nil +} + +func parseS3URL(s3URL string) (bucket string, prefix string, client *s3.Client, err error) { + u, err := url.Parse(s3URL) + if err != nil { + return "", "", nil, fmt.Errorf("s3 store: invalid URL: %w", err) + } + if u.Scheme != "s3" { + return "", "", nil, fmt.Errorf("s3 store: URL scheme must be 's3', got '%s'", u.Scheme) + } + if u.Host == "" { + return "", "", nil, fmt.Errorf("s3 store: bucket name must be specified as host") + } + bucket = u.Host + prefix = strings.TrimPrefix(u.Path, "/") + + accessKey := u.User.Username() + secretKey, _ := u.User.Password() + if accessKey == "" || secretKey == "" { + return "", "", nil, fmt.Errorf("s3 store: access key and secret key must be specified in URL") + } + + region := u.Query().Get("region") + if region == "" { + return "", "", nil, fmt.Errorf("s3 store: region query parameter is required") + } + endpoint := u.Query().Get("endpoint") + + cfg := aws.Config{ + Region: region, + Credentials: credentials.NewStaticCredentialsProvider(accessKey, secretKey, ""), + } + var opts []func(*s3.Options) + if endpoint != "" { + opts = append(opts, func(o *s3.Options) { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = true + }) + } + client = s3.NewFromConfig(cfg, opts...) 
+ return bucket, prefix, client, nil +} + +func (c *s3Store) objectKey(id string) string { + if c.prefix != "" { + return c.prefix + "/" + id + } + return id +} + +func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { + if !fileIDRegex.MatchString(id) { + return 0, errInvalidFileID + } + log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3") + + // Use io.Pipe so we can apply limiters while streaming to S3 + pr, pw := io.Pipe() + var writeErr error + var size int64 + + limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) + go func() { + limitWriter := util.NewLimitWriter(pw, limiters...) + size, writeErr = io.Copy(limitWriter, in) + if writeErr != nil { + pw.CloseWithError(writeErr) + } else { + pw.Close() + } + }() + + key := c.objectKey(id) + _, err := c.client.PutObject(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String(c.bucket), + Key: aws.String(key), + Body: pr, + }) + if err != nil { + // If the limiter caused the error, return the original write error + if writeErr != nil { + return 0, writeErr + } + return 0, fmt.Errorf("s3 store: PutObject failed: %w", err) + } + if writeErr != nil { + // The write goroutine failed but PutObject somehow succeeded; clean up + _, _ = c.client.DeleteObject(context.Background(), &s3.DeleteObjectInput{ + Bucket: aws.String(c.bucket), + Key: aws.String(key), + }) + return 0, writeErr + } + + c.mu.Lock() + c.totalSizeCurrent += size + c.mu.Unlock() + return size, nil +} + +func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) { + if !fileIDRegex.MatchString(id) { + return nil, 0, errInvalidFileID + } + key := c.objectKey(id) + resp, err := c.client.GetObject(context.Background(), &s3.GetObjectInput{ + Bucket: aws.String(c.bucket), + Key: aws.String(key), + }) + if err != nil { + return nil, 0, fmt.Errorf("s3 store: GetObject failed: %w", err) + } + var size int64 + if resp.ContentLength != nil { + size = *resp.ContentLength + } + 
return resp.Body, size, nil +} + +func (c *s3Store) Remove(ids ...string) error { + for _, id := range ids { + if !fileIDRegex.MatchString(id) { + return errInvalidFileID + } + } + // S3 DeleteObjects supports up to 1000 keys per call + for i := 0; i < len(ids); i += 1000 { + end := i + 1000 + if end > len(ids) { + end = len(ids) + } + batch := ids[i:end] + objects := make([]s3types.ObjectIdentifier, len(batch)) + for j, id := range batch { + log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3") + key := c.objectKey(id) + objects[j] = s3types.ObjectIdentifier{ + Key: aws.String(key), + } + } + _, err := c.client.DeleteObjects(context.Background(), &s3.DeleteObjectsInput{ + Bucket: aws.String(c.bucket), + Delete: &s3types.Delete{ + Objects: objects, + Quiet: aws.Bool(true), + }, + }) + if err != nil { + return fmt.Errorf("s3 store: DeleteObjects failed: %w", err) + } + } + // Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern) + size, err := c.computeSize() + if err != nil { + return fmt.Errorf("s3 store: failed to compute size after remove: %w", err) + } + c.mu.Lock() + c.totalSizeCurrent = size + c.mu.Unlock() + return nil +} + +func (c *s3Store) Size() int64 { + c.mu.Lock() + defer c.mu.Unlock() + return c.totalSizeCurrent +} + +func (c *s3Store) Remaining() int64 { + c.mu.Lock() + defer c.mu.Unlock() + remaining := c.totalSizeLimit - c.totalSizeCurrent + if remaining < 0 { + return 0 + } + return remaining +} + +func (c *s3Store) computeSize() (int64, error) { + var size int64 + paginator := s3.NewListObjectsV2Paginator(c.client, &s3.ListObjectsV2Input{ + Bucket: aws.String(c.bucket), + Prefix: aws.String(c.prefixForList()), + }) + for paginator.HasMorePages() { + page, err := paginator.NextPage(context.Background()) + if err != nil { + return 0, err + } + for _, obj := range page.Contents { + if obj.Size != nil { + size += *obj.Size + } + } + } + return size, nil +} + +func (c *s3Store) 
prefixForList() string { + if c.prefix != "" { + return c.prefix + "/" + } + return "" +} diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go new file mode 100644 index 00000000..a1808d0c --- /dev/null +++ b/attachment/store_s3_test.go @@ -0,0 +1,76 @@ +package attachment + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseS3URL_Success(t *testing.T) { + bucket, prefix, client, err := parseS3URL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1") + require.Nil(t, err) + require.Equal(t, "my-bucket", bucket) + require.Equal(t, "attachments", prefix) + require.NotNil(t, client) +} + +func TestParseS3URL_NoPrefix(t *testing.T) { + bucket, prefix, client, err := parseS3URL("s3://AKID:SECRET@my-bucket?region=us-east-1") + require.Nil(t, err) + require.Equal(t, "my-bucket", bucket) + require.Equal(t, "", prefix) + require.NotNil(t, client) +} + +func TestParseS3URL_WithEndpoint(t *testing.T) { + bucket, prefix, client, err := parseS3URL("s3://AKID:SECRET@my-bucket/prefix?region=us-east-1&endpoint=https://s3.example.com") + require.Nil(t, err) + require.Equal(t, "my-bucket", bucket) + require.Equal(t, "prefix", prefix) + require.NotNil(t, client) +} + +func TestParseS3URL_NestedPrefix(t *testing.T) { + bucket, prefix, _, err := parseS3URL("s3://AKID:SECRET@my-bucket/a/b/c?region=us-east-1") + require.Nil(t, err) + require.Equal(t, "my-bucket", bucket) + require.Equal(t, "a/b/c", prefix) +} + +func TestParseS3URL_MissingRegion(t *testing.T) { + _, _, _, err := parseS3URL("s3://AKID:SECRET@my-bucket") + require.Error(t, err) + require.Contains(t, err.Error(), "region") +} + +func TestParseS3URL_MissingCredentials(t *testing.T) { + _, _, _, err := parseS3URL("s3://my-bucket?region=us-east-1") + require.Error(t, err) + require.Contains(t, err.Error(), "access key") +} + +func TestParseS3URL_MissingSecretKey(t *testing.T) { + _, _, _, err := parseS3URL("s3://AKID@my-bucket?region=us-east-1") + require.Error(t, err) + 
require.Contains(t, err.Error(), "secret key") +} + +func TestParseS3URL_WrongScheme(t *testing.T) { + _, _, _, err := parseS3URL("http://AKID:SECRET@my-bucket?region=us-east-1") + require.Error(t, err) + require.Contains(t, err.Error(), "scheme") +} + +func TestParseS3URL_EmptyBucket(t *testing.T) { + _, _, _, err := parseS3URL("s3://AKID:SECRET@?region=us-east-1") + require.Error(t, err) + require.Contains(t, err.Error(), "bucket") +} + +func TestS3Store_ObjectKey(t *testing.T) { + s := &s3Store{prefix: "attachments"} + require.Equal(t, "attachments/abcdefghijkl", s.objectKey("abcdefghijkl")) + + s2 := &s3Store{prefix: ""} + require.Equal(t, "abcdefghijkl", s2.objectKey("abcdefghijkl")) +} diff --git a/cmd/serve.go b/cmd/serve.go index 415868fc..c2ed210a 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -53,6 +53,7 @@ var flagsServe = append( altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}), altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-s3-url", Aliases: []string{"attachment_s3_url"}, EnvVars: []string{"NTFY_ATTACHMENT_S3_URL"}, Usage: "S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment 
cache"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-expiry-duration", Aliases: []string{"attachment_expiry_duration", "X"}, EnvVars: []string{"NTFY_ATTACHMENT_EXPIRY_DURATION"}, Value: util.FormatDuration(server.DefaultAttachmentExpiryDuration), Usage: "duration after which uploaded attachments will be deleted (e.g. 3h, 20h)"}), @@ -166,6 +167,7 @@ func execServe(c *cli.Context) error { authAccessRaw := c.StringSlice("auth-access") authTokensRaw := c.StringSlice("auth-tokens") attachmentCacheDir := c.String("attachment-cache-dir") + attachmentS3URL := c.String("attachment-s3-url") attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit") attachmentFileSizeLimitStr := c.String("attachment-file-size-limit") attachmentExpiryDurationStr := c.String("attachment-expiry-duration") @@ -314,6 +316,10 @@ func execServe(c *cli.Context) error { return errors.New("if smtp-server-listen is set, smtp-server-domain must also be set") } else if attachmentCacheDir != "" && baseURL == "" { return errors.New("if attachment-cache-dir is set, base-url must also be set") + } else if attachmentS3URL != "" && baseURL == "" { + return errors.New("if attachment-s3-url is set, base-url must also be set") + } else if attachmentS3URL != "" && attachmentCacheDir != "" { + return errors.New("attachment-cache-dir and attachment-s3-url are mutually exclusive") } else if baseURL != "" { u, err := url.Parse(baseURL) if err != nil { @@ -457,6 +463,7 @@ func execServe(c *cli.Context) error { conf.AuthAccess = authAccess conf.AuthTokens = authTokens conf.AttachmentCacheDir = attachmentCacheDir + conf.AttachmentS3URL = attachmentS3URL conf.AttachmentTotalSizeLimit = 
attachmentTotalSizeLimit conf.AttachmentFileSizeLimit = attachmentFileSizeLimit conf.AttachmentExpiryDuration = attachmentExpiryDuration diff --git a/docs/config.md b/docs/config.md index b9c8f07f..34484a51 100644 --- a/docs/config.md +++ b/docs/config.md @@ -489,20 +489,23 @@ Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscri ## Attachments If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable -this feature, you have to simply configure an attachment cache directory and a base URL (`attachment-cache-dir`, `base-url`). -Once these options are set and the directory is writable by the server user, you can upload attachments via PUT. +this feature, you have to configure an attachment storage backend and a base URL (`base-url`). Attachments can be stored +either on the local filesystem (`attachment-cache-dir`) or in an S3-compatible object store (`attachment-s3-url`). +Once configured, you can upload attachments via PUT. -By default, attachments are stored in the disk-cache **for only 3 hours**. The main reason for this is to avoid legal issues -and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download +By default, attachments are stored **for only 3 hours**. The main reason for this is to avoid legal issues +and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download feature) to download the file. 
The following config options are relevant to attachments: * `base-url` is the root URL for the ntfy server; this is needed for the generated attachment URLs -* `attachment-cache-dir` is the cache directory for attached files -* `attachment-total-size-limit` is the size limit of the on-disk attachment cache (default: 5G) +* `attachment-cache-dir` is the cache directory for attached files (mutually exclusive with `attachment-s3-url`) +* `attachment-s3-url` is the S3 URL for attachment storage (mutually exclusive with `attachment-cache-dir`) +* `attachment-total-size-limit` is the size limit of the attachment storage (default: 5G) * `attachment-file-size-limit` is the per-file attachment size limit (e.g. 300k, 2M, 100M, default: 15M) * `attachment-expiry-duration` is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h, default: 3h) -Here's an example config using mostly the defaults (except for the cache directory, which is empty by default): +### Filesystem storage +Here's an example config using the local filesystem for attachment storage: === "/etc/ntfy/server.yml (minimal)" ``` yaml @@ -521,6 +524,30 @@ Here's an example config using mostly the defaults (except for the cache directo visitor-attachment-daily-bandwidth-limit: "500M" ``` +### S3 storage +As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3, +MinIO, DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage. + +The `attachment-s3-url` option uses the following format: + +``` +s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] +``` + +When `endpoint` is specified, path-style addressing is enabled automatically (useful for MinIO and other S3-compatible stores). 
+ +=== "/etc/ntfy/server.yml (AWS S3)" + ``` yaml + base-url: "https://ntfy.sh" + attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1" + ``` + +=== "/etc/ntfy/server.yml (MinIO/custom endpoint)" + ``` yaml + base-url: "https://ntfy.sh" + attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com" + ``` + Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit` and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse. @@ -2116,7 +2143,8 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`). | `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) | | `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) | | `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header | -| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. To enable attachments, this has to be set. | +| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. Mutually exclusive with `attachment-s3-url`. | +| `attachment-s3-url` | `NTFY_ATTACHMENT_S3_URL` | *URL* | - | S3 URL for attachment storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). Mutually exclusive with `attachment-cache-dir`. | | `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limits is exceeded, new attachments will be rejected. 
| | `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachment will be rejected. | | `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. | @@ -2219,6 +2247,7 @@ OPTIONS: --auth-startup-queries value, --auth_startup_queries value queries run when the auth database is initialized [$NTFY_AUTH_STARTUP_QUERIES] --auth-default-access value, --auth_default_access value, -p value default permissions if no matching entries in the auth database are found (default: "read-write") [$NTFY_AUTH_DEFAULT_ACCESS] --attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files [$NTFY_ATTACHMENT_CACHE_DIR] + --attachment-s3-url value, --attachment_s3_url value S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_S3_URL] --attachment-total-size-limit value, --attachment_total_size_limit value, -A value limit of the on-disk attachment cache (default: "5G") [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT] --attachment-file-size-limit value, --attachment_file_size_limit value, -Y value per-file attachment size limit (e.g. 300k, 2M, 100M) (default: "15M") [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT] --attachment-expiry-duration value, --attachment_expiry_duration value, -X value duration after which uploaded attachments will be deleted (e.g. 
3h, 20h) (default: "3h") [$NTFY_ATTACHMENT_EXPIRY_DURATION] diff --git a/docs/releases.md b/docs/releases.md index 7a40e5c4..7f5d6b45 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1761,6 +1761,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release * Support PostgreSQL read replicas for offloading non-critical read queries via `database-replica-urls` config option * Add interactive [config generator](config.md#config-generator) to the documentation to help create server configuration files +* Add S3-compatible object storage as an alternative attachment backend via `attachment-s3-url` config option **Bug fixes + maintenance:** diff --git a/go.mod b/go.mod index c073d6aa..ef8564c2 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,18 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect github.com/MicahParks/keyfunc v1.9.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 // indirect + github.com/aws/smithy-go v1.24.2 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 
// indirect diff --git a/go.sum b/go.sum index 1c6eada9..3f373614 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,30 @@ github.com/MicahParks/keyfunc v1.9.0 h1:lhKd5xrFHLNOWrDc4Tyb/Q1AJ4LCzQ48GVJyVIID github.com/MicahParks/keyfunc v1.9.0/go.mod h1:IdnCilugA0O/99dW+/MkvlyrsX8+L8+x95xuVNtM5jw= github.com/SherClockHolmes/webpush-go v1.4.0 h1:ocnzNKWN23T9nvHi6IfyrQjkIc0oJWv1B1pULsf9i3s= github.com/SherClockHolmes/webpush-go v1.4.0/go.mod h1:XSq8pKX11vNV8MJEMwjrlTkxhAj1zKfxmyhdV7Pd6UA= +github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k= +github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 h1:SwGMTMLIlvDNyhMteQ6r8IJSBPlRdXX5d4idhIGbkXA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21/go.mod h1:UUxgWxofmOdAMuqEsSppbDtGKLfR04HGsD0HXzvhI1k= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= 
+github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 h1:qtJZ70afD3ISKWnoX3xB0J2otEqu3LqicRcDBqsj0hQ= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12/go.mod h1:v2pNpJbRNl4vEUWEh5ytQok0zACAKfdmKS51Hotc3pQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 h1:siU1A6xjUZ2N8zjTHSXFhB9L/2OY8Dqs0xXiLjF30jA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20/go.mod h1:4TLZCmVJDM3FOu5P5TJP0zOlu9zWgDWU7aUxWbr+rcw= +github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 h1:csi9NLpFZXb9fxY7rS1xVzgPRGMt7MSNWeQ6eo247kE= +github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1/go.mod h1:qXVal5H0ChqXP63t6jze5LmFalc7+ZE7wOdLtZ0LCP0= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/server/config.go b/server/config.go index 8ead312c..97f72a1c 100644 --- a/server/config.go +++ b/server/config.go @@ -112,6 +112,7 @@ type Config struct { AuthBcryptCost int AuthStatsQueueWriterInterval time.Duration AttachmentCacheDir string + AttachmentS3URL string AttachmentTotalSizeLimit int64 AttachmentFileSizeLimit int64 AttachmentExpiryDuration time.Duration diff --git a/server/file_cache_test.go b/server/file_cache_test.go deleted file mode 100644 index e7dee3b3..00000000 --- a/server/file_cache_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package server - -import ( - "bytes" - "fmt" - "github.com/stretchr/testify/require" - 
"heckel.io/ntfy/v2/util" - "os" - "strings" - "testing" -) - -var ( - oneKilobyteArray = make([]byte, 1024) -) - -func TestFileCache_Write_Success(t *testing.T) { - dir, c := newTestFileCache(t) - size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) - require.Nil(t, err) - require.Equal(t, int64(11), size) - require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl")) - require.Equal(t, int64(11), c.Size()) - require.Equal(t, int64(10229), c.Remaining()) -} - -func TestFileCache_Write_Remove_Success(t *testing.T) { - dir, c := newTestFileCache(t) // max = 10k (10240), each = 1k (1024) - for i := 0; i < 10; i++ { // 10x999 = 9990 - size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) - require.Nil(t, err) - require.Equal(t, int64(999), size) - } - require.Equal(t, int64(9990), c.Size()) - require.Equal(t, int64(250), c.Remaining()) - require.FileExists(t, dir+"/abcdefghijk1") - require.FileExists(t, dir+"/abcdefghijk5") - - require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5")) - require.NoFileExists(t, dir+"/abcdefghijk1") - require.NoFileExists(t, dir+"/abcdefghijk5") - require.Equal(t, int64(7992), c.Size()) - require.Equal(t, int64(2248), c.Remaining()) -} - -func TestFileCache_Write_FailedTotalSizeLimit(t *testing.T) { - dir, c := newTestFileCache(t) - for i := 0; i < 10; i++ { - size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) - require.Nil(t, err) - require.Equal(t, int64(1024), size) - } - _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) - require.Equal(t, util.ErrLimitReached, err) - require.NoFileExists(t, dir+"/abcdefghijkX") -} - -func TestFileCache_Write_FailedAdditionalLimiter(t *testing.T) { - dir, c := newTestFileCache(t) - _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) - require.Equal(t, util.ErrLimitReached, err) - require.NoFileExists(t, 
dir+"/abcdefghijkl") -} - -func newTestFileCache(t *testing.T) (dir string, cache *fileCache) { - dir = t.TempDir() - cache, err := newFileCache(dir, 10*1024) - require.Nil(t, err) - return dir, cache -} - -func readFile(t *testing.T, f string) string { - b, err := os.ReadFile(f) - require.Nil(t, err) - return string(b) -} diff --git a/server/log.go b/server/log.go index 03600c0d..e4ddc178 100644 --- a/server/log.go +++ b/server/log.go @@ -24,7 +24,6 @@ const ( tagSMTP = "smtp" // Receive email tagEmail = "email" // Send email tagTwilio = "twilio" - tagFileCache = "file_cache" tagMessageCache = "message_cache" tagStripe = "stripe" tagAccount = "account" diff --git a/server/server.go b/server/server.go index 24c712bd..434f93af 100644 --- a/server/server.go +++ b/server/server.go @@ -32,6 +32,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" + "heckel.io/ntfy/v2/attachment" "heckel.io/ntfy/v2/db" "heckel.io/ntfy/v2/db/pg" "heckel.io/ntfy/v2/log" @@ -64,7 +65,7 @@ type Server struct { userManager *user.Manager // Might be nil! messageCache *message.Cache // Database that stores the messages webPush *webpush.Store // Database that stores web push subscriptions - fileCache *fileCache // File system based cache that stores attachments + fileCache attachment.Store // Attachment store (file system or S3) stripe stripeAPI // Stripe API, can be replaced with a mock priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) 
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set @@ -227,12 +228,9 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } - var fileCache *fileCache - if conf.AttachmentCacheDir != "" { - fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit) - if err != nil { - return nil, err - } + fileCache, err := createAttachmentStore(conf) + if err != nil { + return nil, err } var userManager *user.Manager if conf.AuthFile != "" || pool != nil { @@ -301,6 +299,15 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) { return message.NewMemStore() } +func createAttachmentStore(conf *Config) (attachment.Store, error) { + if conf.AttachmentS3URL != "" { + return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit) + } else if conf.AttachmentCacheDir != "" { + return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit) + } + return nil, nil +} + // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts // a manager go routine to print stats and prune messages. func (s *Server) Run() error { @@ -752,7 +759,7 @@ func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor) // Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it // can associate the download bandwidth with the uploader. 
func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error { - if s.config.AttachmentCacheDir == "" { + if s.fileCache == nil { return errHTTPInternalError } matches := fileRegex.FindStringSubmatch(r.URL.Path) @@ -760,16 +767,16 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) return errHTTPInternalErrorInvalidPath } messageID := matches[1] - file := filepath.Join(s.config.AttachmentCacheDir, messageID) - stat, err := os.Stat(file) + reader, size, err := s.fileCache.Read(messageID) if err != nil { return errHTTPNotFound.Fields(log.Context{ "message_id": messageID, - "error_context": "filesystem", + "error_context": "attachment_store", }) } + defer reader.Close() w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests - w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size())) + w.Header().Set("Content-Length", fmt.Sprintf("%d", size)) if r.Method == http.MethodHead { return nil } @@ -805,19 +812,14 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) } else if m.Sender.IsValid() { bandwidthVisitor = s.visitor(m.Sender, nil) } - if !bandwidthVisitor.BandwidthAllowed(stat.Size()) { + if !bandwidthVisitor.BandwidthAllowed(size) { return errHTTPTooManyRequestsLimitAttachmentBandwidth.With(m) } // Actually send file - f, err := os.Open(file) - if err != nil { - return err - } - defer f.Close() if m.Attachment.Name != "" { w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(m.Attachment.Name)) } - _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f) + _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), reader) return err } @@ -1408,7 +1410,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) { } func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error { - if s.fileCache == nil || 
s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" { + if s.fileCache == nil || s.config.BaseURL == "" { return errHTTPBadRequestAttachmentsDisallowed.With(m) } vinfo, err := v.Info() diff --git a/server/server.yml b/server/server.yml index 43cb5fb4..e6f7afee 100644 --- a/server/server.yml +++ b/server/server.yml @@ -159,6 +159,7 @@ # - attachment-expiry-duration is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h) # # attachment-cache-dir: +# attachment-s3-url: "s3://ACCESS_KEY:SECRET_KEY@bucket/prefix?region=us-east-1" # attachment-total-size-limit: "5G" # attachment-file-size-limit: "15M" # attachment-expiry-duration: "3h" diff --git a/server/server_manager.go b/server/server_manager.go index afed7b33..5bf42924 100644 --- a/server/server_manager.go +++ b/server/server_manager.go @@ -99,6 +99,9 @@ func (s *Server) execManager() { mset(metricUsers, usersCount) mset(metricSubscribers, subscribers) mset(metricTopics, topicsCount) + if s.fileCache != nil { + mset(metricAttachmentsTotalSize, s.fileCache.Size()) + } } func (s *Server) pruneVisitors() { From b4ec6fa8df41f9ad7e7079227a9a9c9e13b02534 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 15 Mar 2026 10:12:23 -0400 Subject: [PATCH 02/32] AWS deps.. 
--- go.mod | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index ef8564c2..f3cd7791 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,9 @@ require github.com/pkg/errors v0.9.1 // indirect require ( firebase.google.com/go/v4 v4.19.0 github.com/SherClockHolmes/webpush-go v1.4.0 + github.com/aws/aws-sdk-go-v2 v1.41.4 + github.com/aws/aws-sdk-go-v2/credentials v1.19.12 + github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 github.com/jackc/pgx/v5 v5.8.0 github.com/microcosm-cc/bluemonday v1.0.27 github.com/prometheus/client_golang v1.23.2 @@ -52,9 +55,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect github.com/MicahParks/keyfunc v1.9.0 // indirect - github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 // indirect @@ -62,7 +63,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect - github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 // indirect github.com/aws/smithy-go v1.24.2 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect From 790ba243c766fed20f42df126dd0a69e9e662bb6 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 16 Mar 2026 09:48:26 -0400 Subject: [PATCH 03/32] S3 WIP --- attachment/store_s3.go | 186 +++------ attachment/store_s3_test.go | 298 ++++++++++++--- go.mod | 12 - go.sum | 24 -- s3/client.go | 325 
++++++++++++++++ s3/client_test.go | 727 ++++++++++++++++++++++++++++++++++++ s3/types.go | 65 ++++ s3/util.go | 161 ++++++++ s3/util_test.go | 181 +++++++++ tools/s3cli/main.go | 164 ++++++++ 10 files changed, 1917 insertions(+), 226 deletions(-) create mode 100644 s3/client.go create mode 100644 s3/client_test.go create mode 100644 s3/types.go create mode 100644 s3/util.go create mode 100644 s3/util_test.go create mode 100644 tools/s3cli/main.go diff --git a/attachment/store_s3.go b/attachment/store_s3.go index 118da4ce..5c47a81b 100644 --- a/attachment/store_s3.go +++ b/attachment/store_s3.go @@ -4,24 +4,20 @@ import ( "context" "fmt" "io" - "net/url" - "strings" + "os" "sync" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/service/s3" - s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/s3" "heckel.io/ntfy/v2/util" ) -const tagS3Store = "s3_store" +const ( + tagS3Store = "s3_store" +) type s3Store struct { client *s3.Client - bucket string - prefix string totalSizeCurrent int64 totalSizeLimit int64 mu sync.Mutex @@ -31,14 +27,12 @@ type s3Store struct { // // s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) { - bucket, prefix, client, err := parseS3URL(s3URL) + cfg, err := s3.ParseURL(s3URL) if err != nil { return nil, err } store := &s3Store{ - client: client, - bucket: bucket, - prefix: prefix, + client: s3.New(cfg), totalSizeLimit: totalSizeLimit, } if totalSizeLimit > 0 { @@ -51,98 +45,40 @@ func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) { return store, nil } -func parseS3URL(s3URL string) (bucket string, prefix string, client *s3.Client, err error) { - u, err := url.Parse(s3URL) - if err != nil { - return "", "", nil, fmt.Errorf("s3 store: invalid URL: %w", err) - } - if u.Scheme != "s3" { - return "", "", nil, fmt.Errorf("s3 
store: URL scheme must be 's3', got '%s'", u.Scheme) - } - if u.Host == "" { - return "", "", nil, fmt.Errorf("s3 store: bucket name must be specified as host") - } - bucket = u.Host - prefix = strings.TrimPrefix(u.Path, "/") - - accessKey := u.User.Username() - secretKey, _ := u.User.Password() - if accessKey == "" || secretKey == "" { - return "", "", nil, fmt.Errorf("s3 store: access key and secret key must be specified in URL") - } - - region := u.Query().Get("region") - if region == "" { - return "", "", nil, fmt.Errorf("s3 store: region query parameter is required") - } - endpoint := u.Query().Get("endpoint") - - cfg := aws.Config{ - Region: region, - Credentials: credentials.NewStaticCredentialsProvider(accessKey, secretKey, ""), - } - var opts []func(*s3.Options) - if endpoint != "" { - opts = append(opts, func(o *s3.Options) { - o.BaseEndpoint = aws.String(endpoint) - o.UsePathStyle = true - }) - } - client = s3.NewFromConfig(cfg, opts...) - return bucket, prefix, client, nil -} - -func (c *s3Store) objectKey(id string) string { - if c.prefix != "" { - return c.prefix + "/" + id - } - return id -} - func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { if !fileIDRegex.MatchString(id) { return 0, errInvalidFileID } log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3") - // Use io.Pipe so we can apply limiters while streaming to S3 - pr, pw := io.Pipe() - var writeErr error - var size int64 - + // Write through limiters into a temp file. This avoids buffering the full attachment in + // memory while still giving us the Content-Length that PutObject requires. limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - go func() { - limitWriter := util.NewLimitWriter(pw, limiters...) 
- size, writeErr = io.Copy(limitWriter, in) - if writeErr != nil { - pw.CloseWithError(writeErr) - } else { - pw.Close() - } - }() - - key := c.objectKey(id) - _, err := c.client.PutObject(context.Background(), &s3.PutObjectInput{ - Bucket: aws.String(c.bucket), - Key: aws.String(key), - Body: pr, - }) + tmpFile, err := os.CreateTemp("", "ntfy-s3-upload-*") if err != nil { - // If the limiter caused the error, return the original write error - if writeErr != nil { - return 0, writeErr - } - return 0, fmt.Errorf("s3 store: PutObject failed: %w", err) + return 0, fmt.Errorf("s3 store: failed to create temp file: %w", err) } - if writeErr != nil { - // The write goroutine failed but PutObject somehow succeeded; clean up - _, _ = c.client.DeleteObject(context.Background(), &s3.DeleteObjectInput{ - Bucket: aws.String(c.bucket), - Key: aws.String(key), - }) - return 0, writeErr + tmpPath := tmpFile.Name() + defer os.Remove(tmpPath) + limitWriter := util.NewLimitWriter(tmpFile, limiters...) + size, err := io.Copy(limitWriter, in) + if err != nil { + tmpFile.Close() + return 0, err + } + if err := tmpFile.Close(); err != nil { + return 0, err } + // Re-open the temp file for reading and stream it to S3 + f, err := os.Open(tmpPath) + if err != nil { + return 0, err + } + defer f.Close() + if err := c.client.PutObject(context.Background(), id, f, size); err != nil { + return 0, err + } c.mu.Lock() c.totalSizeCurrent += size c.mu.Unlock() @@ -153,19 +89,7 @@ func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) { if !fileIDRegex.MatchString(id) { return nil, 0, errInvalidFileID } - key := c.objectKey(id) - resp, err := c.client.GetObject(context.Background(), &s3.GetObjectInput{ - Bucket: aws.String(c.bucket), - Key: aws.String(key), - }) - if err != nil { - return nil, 0, fmt.Errorf("s3 store: GetObject failed: %w", err) - } - var size int64 - if resp.ContentLength != nil { - size = *resp.ContentLength - } - return resp.Body, size, nil + return 
c.client.GetObject(context.Background(), id) } func (c *s3Store) Remove(ids ...string) error { @@ -181,23 +105,11 @@ func (c *s3Store) Remove(ids ...string) error { end = len(ids) } batch := ids[i:end] - objects := make([]s3types.ObjectIdentifier, len(batch)) - for j, id := range batch { + for _, id := range batch { log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3") - key := c.objectKey(id) - objects[j] = s3types.ObjectIdentifier{ - Key: aws.String(key), - } } - _, err := c.client.DeleteObjects(context.Background(), &s3.DeleteObjectsInput{ - Bucket: aws.String(c.bucket), - Delete: &s3types.Delete{ - Objects: objects, - Quiet: aws.Bool(true), - }, - }) - if err != nil { - return fmt.Errorf("s3 store: DeleteObjects failed: %w", err) + if err := c.client.DeleteObjects(context.Background(), batch); err != nil { + return err } } // Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern) @@ -227,29 +139,15 @@ func (c *s3Store) Remaining() int64 { return remaining } +// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix. 
func (c *s3Store) computeSize() (int64, error) { - var size int64 - paginator := s3.NewListObjectsV2Paginator(c.client, &s3.ListObjectsV2Input{ - Bucket: aws.String(c.bucket), - Prefix: aws.String(c.prefixForList()), - }) - for paginator.HasMorePages() { - page, err := paginator.NextPage(context.Background()) - if err != nil { - return 0, err - } - for _, obj := range page.Contents { - if obj.Size != nil { - size += *obj.Size - } - } + objects, err := c.client.ListAllObjects(context.Background()) + if err != nil { + return 0, err } - return size, nil -} - -func (c *s3Store) prefixForList() string { - if c.prefix != "" { - return c.prefix + "/" + var totalSize int64 + for _, obj := range objects { + totalSize += obj.Size } - return "" + return totalSize, nil } diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index a1808d0c..c898244d 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -1,76 +1,282 @@ package attachment import ( + "bytes" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" "testing" "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/s3" + "heckel.io/ntfy/v2/util" ) -func TestParseS3URL_Success(t *testing.T) { - bucket, prefix, client, err := parseS3URL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1") +// --- Integration tests using a mock S3 server --- + +func TestS3Store_WriteReadRemove(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + + // Write + size, err := store.Write("abcdefghijkl", strings.NewReader("hello world")) require.Nil(t, err) - require.Equal(t, "my-bucket", bucket) - require.Equal(t, "attachments", prefix) - require.NotNil(t, client) -} + require.Equal(t, int64(11), size) + require.Equal(t, int64(11), store.Size()) -func TestParseS3URL_NoPrefix(t *testing.T) { - bucket, prefix, client, err := 
parseS3URL("s3://AKID:SECRET@my-bucket?region=us-east-1") + // Read back + reader, readSize, err := store.Read("abcdefghijkl") require.Nil(t, err) - require.Equal(t, "my-bucket", bucket) - require.Equal(t, "", prefix) - require.NotNil(t, client) -} - -func TestParseS3URL_WithEndpoint(t *testing.T) { - bucket, prefix, client, err := parseS3URL("s3://AKID:SECRET@my-bucket/prefix?region=us-east-1&endpoint=https://s3.example.com") + require.Equal(t, int64(11), readSize) + data, err := io.ReadAll(reader) + reader.Close() require.Nil(t, err) - require.Equal(t, "my-bucket", bucket) - require.Equal(t, "prefix", prefix) - require.NotNil(t, client) + require.Equal(t, "hello world", string(data)) + + // Remove + require.Nil(t, store.Remove("abcdefghijkl")) + require.Equal(t, int64(0), store.Size()) + + // Read after remove should fail + _, _, err = store.Read("abcdefghijkl") + require.Error(t, err) } -func TestParseS3URL_NestedPrefix(t *testing.T) { - bucket, prefix, _, err := parseS3URL("s3://AKID:SECRET@my-bucket/a/b/c?region=us-east-1") +func TestS3Store_WriteNoPrefix(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + store := newTestS3Store(t, server, "my-bucket", "", 10*1024) + + size, err := store.Write("abcdefghijkl", strings.NewReader("test")) require.Nil(t, err) - require.Equal(t, "my-bucket", bucket) - require.Equal(t, "a/b/c", prefix) + require.Equal(t, int64(4), size) + + reader, _, err := store.Read("abcdefghijkl") + require.Nil(t, err) + data, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "test", string(data)) } -func TestParseS3URL_MissingRegion(t *testing.T) { - _, _, _, err := parseS3URL("s3://AKID:SECRET@my-bucket") +func TestS3Store_WriteTotalSizeLimit(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + store := newTestS3Store(t, server, "my-bucket", "pfx", 100) + + // First write fits + _, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80))) + require.Nil(t, 
err) + require.Equal(t, int64(80), store.Size()) + require.Equal(t, int64(20), store.Remaining()) + + // Second write exceeds total limit + _, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50))) + require.Equal(t, util.ErrLimitReached, err) +} + +func TestS3Store_WriteFileSizeLimit(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + + _, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100)) + require.Equal(t, util.ErrLimitReached, err) +} + +func TestS3Store_WriteRemoveMultiple(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + + for i := 0; i < 5; i++ { + _, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100))) + require.Nil(t, err) + } + require.Equal(t, int64(500), store.Size()) + + require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3")) + require.Equal(t, int64(300), store.Size()) +} + +func TestS3Store_ReadNotFound(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + + _, _, err := store.Read("abcdefghijkl") require.Error(t, err) - require.Contains(t, err.Error(), "region") } -func TestParseS3URL_MissingCredentials(t *testing.T) { - _, _, _, err := parseS3URL("s3://my-bucket?region=us-east-1") - require.Error(t, err) - require.Contains(t, err.Error(), "access key") +func TestS3Store_InvalidID(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + + _, err := store.Write("bad", strings.NewReader("x")) + require.Equal(t, errInvalidFileID, err) + + _, _, err = store.Read("bad") + require.Equal(t, errInvalidFileID, err) + + err = store.Remove("bad") + require.Equal(t, errInvalidFileID, err) } -func 
TestParseS3URL_MissingSecretKey(t *testing.T) { - _, _, _, err := parseS3URL("s3://AKID@my-bucket?region=us-east-1") - require.Error(t, err) - require.Contains(t, err.Error(), "secret key") +// --- Helpers --- + +func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store { + t.Helper() + // httptest.NewTLSServer URL is like "https://127.0.0.1:PORT" + host := strings.TrimPrefix(server.URL, "https://") + s := &s3Store{ + client: &s3.Client{ + AccessKey: "AKID", + SecretKey: "SECRET", + Region: "us-east-1", + Endpoint: host, + Bucket: bucket, + Prefix: prefix, + PathStyle: true, + HTTPClient: server.Client(), + }, + totalSizeLimit: totalSizeLimit, + } + // Compute initial size (should be 0 for fresh mock) + size, err := s.computeSize() + require.Nil(t, err) + s.totalSizeCurrent = size + return s } -func TestParseS3URL_WrongScheme(t *testing.T) { - _, _, _, err := parseS3URL("http://AKID:SECRET@my-bucket?region=us-east-1") - require.Error(t, err) - require.Contains(t, err.Error(), "scheme") +// --- Mock S3 server --- +// +// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and +// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. 
+ +type mockS3Server struct { + objects map[string][]byte // full key (bucket/key) -> body + mu sync.RWMutex } -func TestParseS3URL_EmptyBucket(t *testing.T) { - _, _, _, err := parseS3URL("s3://AKID:SECRET@?region=us-east-1") - require.Error(t, err) - require.Contains(t, err.Error(), "bucket") +func newMockS3Server() *httptest.Server { + m := &mockS3Server{objects: make(map[string][]byte)} + return httptest.NewTLSServer(m) } -func TestS3Store_ObjectKey(t *testing.T) { - s := &s3Store{prefix: "attachments"} - require.Equal(t, "attachments/abcdefghijkl", s.objectKey("abcdefghijkl")) +func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Path is /{bucket}[/{key...}] + path := strings.TrimPrefix(r.URL.Path, "/") - s2 := &s3Store{prefix: ""} - require.Equal(t, "abcdefghijkl", s2.objectKey("abcdefghijkl")) + switch { + case r.Method == http.MethodPut: + m.handlePut(w, r, path) + case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2": + m.handleList(w, r, path) + case r.Method == http.MethodGet: + m.handleGet(w, r, path) + case r.Method == http.MethodPost && r.URL.Query().Has("delete"): + m.handleDelete(w, r, path) + default: + http.Error(w, "not implemented", http.StatusNotImplemented) + } +} + +func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + m.mu.Lock() + m.objects[path] = body + m.mu.Unlock() + w.WriteHeader(http.StatusOK) +} + +func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { + m.mu.RLock() + body, ok := m.objects[path] + m.mu.RUnlock() + if !ok { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`NoSuchKeyThe specified key does not exist.`)) + return + } + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +func (m *mockS3Server) handleDelete(w 
http.ResponseWriter, r *http.Request, bucketPath string) { + // bucketPath is just the bucket name + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + var req struct { + Objects []struct { + Key string `xml:"Key"` + } `xml:"Object"` + } + if err := xml.Unmarshal(body, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + m.mu.Lock() + for _, obj := range req.Objects { + delete(m.objects, bucketPath+"/"+obj.Key) + } + m.mu.Unlock() + w.WriteHeader(http.StatusOK) + w.Write([]byte(``)) +} + +func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) { + prefix := r.URL.Query().Get("prefix") + m.mu.RLock() + var contents []s3ListObject + for key, body := range m.objects { + // key is "bucket/objectkey", strip bucket prefix + objKey := strings.TrimPrefix(key, bucketPath+"/") + if objKey == key { + continue // different bucket + } + if prefix == "" || strings.HasPrefix(objKey, prefix) { + contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))}) + } + } + m.mu.RUnlock() + + resp := s3ListResponse{ + Contents: contents, + IsTruncated: false, + } + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + xml.NewEncoder(w).Encode(resp) +} + +type s3ListResponse struct { + XMLName xml.Name `xml:"ListBucketResult"` + Contents []s3ListObject `xml:"Contents"` + IsTruncated bool `xml:"IsTruncated"` +} + +type s3ListObject struct { + Key string `xml:"Key"` + Size int64 `xml:"Size"` } diff --git a/go.mod b/go.mod index f3cd7791..c073d6aa 100644 --- a/go.mod +++ b/go.mod @@ -30,9 +30,6 @@ require github.com/pkg/errors v0.9.1 // indirect require ( firebase.google.com/go/v4 v4.19.0 github.com/SherClockHolmes/webpush-go v1.4.0 - github.com/aws/aws-sdk-go-v2 v1.41.4 - github.com/aws/aws-sdk-go-v2/credentials v1.19.12 - github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 github.com/jackc/pgx/v5 v5.8.0 
github.com/microcosm-cc/bluemonday v1.0.27 github.com/prometheus/client_golang v1.23.2 @@ -55,15 +52,6 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect github.com/MicahParks/keyfunc v1.9.0 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 // indirect - github.com/aws/smithy-go v1.24.2 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 3f373614..1c6eada9 100644 --- a/go.sum +++ b/go.sum @@ -40,30 +40,6 @@ github.com/MicahParks/keyfunc v1.9.0 h1:lhKd5xrFHLNOWrDc4Tyb/Q1AJ4LCzQ48GVJyVIID github.com/MicahParks/keyfunc v1.9.0/go.mod h1:IdnCilugA0O/99dW+/MkvlyrsX8+L8+x95xuVNtM5jw= github.com/SherClockHolmes/webpush-go v1.4.0 h1:ocnzNKWN23T9nvHi6IfyrQjkIc0oJWv1B1pULsf9i3s= github.com/SherClockHolmes/webpush-go v1.4.0/go.mod h1:XSq8pKX11vNV8MJEMwjrlTkxhAj1zKfxmyhdV7Pd6UA= -github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k= -github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7/go.mod 
h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= -github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8= -github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 h1:SwGMTMLIlvDNyhMteQ6r8IJSBPlRdXX5d4idhIGbkXA= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21/go.mod h1:UUxgWxofmOdAMuqEsSppbDtGKLfR04HGsD0HXzvhI1k= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12 h1:qtJZ70afD3ISKWnoX3xB0J2otEqu3LqicRcDBqsj0hQ= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.12/go.mod h1:v2pNpJbRNl4vEUWEh5ytQok0zACAKfdmKS51Hotc3pQ= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20 h1:siU1A6xjUZ2N8zjTHSXFhB9L/2OY8Dqs0xXiLjF30jA= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.20/go.mod h1:4TLZCmVJDM3FOu5P5TJP0zOlu9zWgDWU7aUxWbr+rcw= -github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 h1:csi9NLpFZXb9fxY7rS1xVzgPRGMt7MSNWeQ6eo247kE= -github.com/aws/aws-sdk-go-v2/service/s3 
v1.97.1/go.mod h1:qXVal5H0ChqXP63t6jze5LmFalc7+ZE7wOdLtZ0LCP0= -github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= -github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/s3/client.go b/s3/client.go new file mode 100644 index 00000000..7fdd8093 --- /dev/null +++ b/s3/client.go @@ -0,0 +1,325 @@ +// Package s3 provides a minimal S3-compatible client that works with AWS S3, DigitalOcean Spaces, +// GCP Cloud Storage, MinIO, Backblaze B2, and other S3-compatible providers. It uses raw HTTP +// requests with AWS Signature V4 signing, no AWS SDK dependency required. +package s3 + +import ( + "bytes" + "context" + "crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers + "encoding/base64" + "encoding/hex" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "time" +) + +// Client is a minimal S3-compatible client. It supports PutObject, GetObject, DeleteObjects, +// and ListObjectsV2 operations using AWS Signature V4 signing. The bucket and optional key prefix +// are fixed at construction time. All operations target the same bucket and prefix. +// +// Fields must not be modified after the Client is passed to any method or goroutine. +type Client struct { + AccessKey string // AWS access key ID + SecretKey string // AWS secret access key + Region string // e.g. "us-east-1" + Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com" + Bucket string // S3 bucket name + Prefix string // optional key prefix (e.g. 
"attachments"); prepended to all keys automatically + PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style + HTTPClient *http.Client // if nil, http.DefaultClient is used +} + +// New creates a new S3 client from the given Config. +func New(config *Config) *Client { + return &Client{ + AccessKey: config.AccessKey, + SecretKey: config.SecretKey, + Region: config.Region, + Endpoint: config.Endpoint, + Bucket: config.Bucket, + Prefix: config.Prefix, + PathStyle: config.PathStyle, + } +} + +// PutObject uploads body to the given key. The key is automatically prefixed with the client's +// configured prefix. The body size must be known in advance. The payload is sent as +// UNSIGNED-PAYLOAD, which is supported by all major S3-compatible providers over HTTPS. +func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, size int64) error { + fullKey := c.objectKey(key) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body) + if err != nil { + return fmt.Errorf("s3: PutObject request: %w", err) + } + req.ContentLength = size + c.signV4(req, unsignedPayload) + resp, err := c.httpClient().Do(req) + if err != nil { + return fmt.Errorf("s3: PutObject: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return parseError(resp) + } + return nil +} + +// GetObject downloads an object. The key is automatically prefixed with the client's configured +// prefix. The caller must close the returned ReadCloser. 
+func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) {
+	fullKey := c.objectKey(key)
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(fullKey), nil)
+	if err != nil {
+		return nil, 0, fmt.Errorf("s3: GetObject request: %w", err)
+	}
+	c.signV4(req, emptyPayloadHash)
+	resp, err := c.httpClient().Do(req)
+	if err != nil {
+		return nil, 0, fmt.Errorf("s3: GetObject: %w", err)
+	}
+	if resp.StatusCode/100 != 2 {
+		err := parseError(resp)
+		resp.Body.Close()
+		return nil, 0, err
+	}
+	return resp.Body, resp.ContentLength, nil
+}
+
+// DeleteObjects removes multiple objects in a single batch request. Keys are automatically
+// prefixed with the client's configured prefix. S3 supports up to 1000 keys per call; the
+// caller is responsible for batching if needed.
+//
+// Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present
+// in the response, they are returned as a combined error.
+func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
+	var body bytes.Buffer
+	body.WriteString("<Delete><Quiet>true</Quiet>")
+	for _, key := range keys {
+		body.WriteString("<Object><Key>")
+		xml.EscapeText(&body, []byte(c.objectKey(key)))
+		body.WriteString("</Key></Object>")
+	}
+	body.WriteString("</Delete>")
+	bodyBytes := body.Bytes()
+	payloadHash := sha256Hex(bodyBytes)
+
+	// Content-MD5 is required by the S3 protocol for DeleteObjects requests. 
+ md5Sum := md5.Sum(bodyBytes) //nolint:gosec + contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:]) + + reqURL := c.bucketURL() + "?delete=" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes)) + if err != nil { + return fmt.Errorf("s3: DeleteObjects request: %w", err) + } + req.ContentLength = int64(len(bodyBytes)) + req.Header.Set("Content-Type", "application/xml") + req.Header.Set("Content-MD5", contentMD5) + c.signV4(req, payloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return fmt.Errorf("s3: DeleteObjects: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return parseError(resp) + } + + // S3 may return HTTP 200 with per-key errors in the response body + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return fmt.Errorf("s3: DeleteObjects read response: %w", err) + } + var result deleteResult + if err := xml.Unmarshal(respBody, &result); err != nil { + return nil // If we can't parse, assume success (Quiet mode returns empty body on success) + } + if len(result.Errors) > 0 { + var msgs []string + for _, e := range result.Errors { + msgs = append(msgs, fmt.Sprintf("%s: %s", e.Key, e.Message)) + } + return fmt.Errorf("s3: DeleteObjects partial failure: %s", strings.Join(msgs, "; ")) + } + return nil +} + +// ListObjects performs a single ListObjectsV2 request using the client's configured prefix. +// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000). 
+func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) { + query := url.Values{"list-type": {"2"}} + if prefix := c.prefixForList(); prefix != "" { + query.Set("prefix", prefix) + } + if continuationToken != "" { + query.Set("continuation-token", continuationToken) + } + if maxKeys > 0 { + query.Set("max-keys", strconv.Itoa(maxKeys)) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil) + if err != nil { + return nil, fmt.Errorf("s3: ListObjects request: %w", err) + } + c.signV4(req, emptyPayloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return nil, fmt.Errorf("s3: ListObjects: %w", err) + } + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + resp.Body.Close() + if err != nil { + return nil, fmt.Errorf("s3: ListObjects read: %w", err) + } + if resp.StatusCode/100 != 2 { + return nil, parseErrorFromBytes(resp.StatusCode, respBody) + } + var result listObjectsV2Response + if err := xml.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("s3: ListObjects XML: %w", err) + } + objects := make([]Object, len(result.Contents)) + for i, obj := range result.Contents { + objects[i] = Object(obj) + } + return &ListResult{ + Objects: objects, + IsTruncated: result.IsTruncated, + NextContinuationToken: result.NextContinuationToken, + }, nil +} + +// ListAllObjects returns all objects under the client's configured prefix by paginating through +// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve. +func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { + const maxPages = 10000 + var all []Object + var token string + for page := 0; page < maxPages; page++ { + result, err := c.ListObjects(ctx, token, 0) + if err != nil { + return nil, err + } + all = append(all, result.Objects...) 
+ if !result.IsTruncated { + return all, nil + } + token = result.NextContinuationToken + } + return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) +} + +// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256 +// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. +func (c *Client) signV4(req *http.Request, payloadHash string) { + now := time.Now().UTC() + datestamp := now.Format("20060102") + amzDate := now.Format("20060102T150405Z") + + // Required headers + req.Header.Set("Host", c.hostHeader()) + req.Header.Set("X-Amz-Date", amzDate) + req.Header.Set("X-Amz-Content-Sha256", payloadHash) + + // Canonical headers (all headers we set, sorted by lowercase key) + signedKeys := make([]string, 0, len(req.Header)) + canonHeaders := make(map[string]string, len(req.Header)) + for k := range req.Header { + lk := strings.ToLower(k) + signedKeys = append(signedKeys, lk) + canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k)) + } + sort.Strings(signedKeys) + signedHeadersStr := strings.Join(signedKeys, ";") + var chBuf strings.Builder + for _, k := range signedKeys { + chBuf.WriteString(k) + chBuf.WriteByte(':') + chBuf.WriteString(canonHeaders[k]) + chBuf.WriteByte('\n') + } + + // Canonical request + canonicalRequest := strings.Join([]string{ + req.Method, + canonicalURI(req.URL), + canonicalQueryString(req.URL.Query()), + chBuf.String(), + signedHeadersStr, + payloadHash, + }, "\n") + + // String to sign + credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request" + stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest)) + + // Signing key + signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256( + []byte("AWS4"+c.SecretKey), []byte(datestamp)), + []byte(c.Region)), + []byte("s3")), + []byte("aws4_request")) + + signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) + 
req.Header.Set("Authorization", fmt.Sprintf( + "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", + c.AccessKey, credentialScope, signedHeadersStr, signature, + )) +} + +func (c *Client) httpClient() *http.Client { + if c.HTTPClient != nil { + return c.HTTPClient + } + return http.DefaultClient +} + +// objectKey prepends the configured prefix to the given key. +func (c *Client) objectKey(key string) string { + if c.Prefix != "" { + return c.Prefix + "/" + key + } + return key +} + +// prefixForList returns the prefix to use in ListObjectsV2 requests, +// with a trailing slash so that only objects under the prefix directory are returned. +func (c *Client) prefixForList() string { + if c.Prefix != "" { + return c.Prefix + "/" + } + return "" +} + +// bucketURL returns the base URL for bucket-level operations. +func (c *Client) bucketURL() string { + if c.PathStyle { + return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket) + } + return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint) +} + +// objectURL returns the full URL for an object (key should already include the prefix). +// Each path segment is URI-encoded to handle special characters in keys. +func (c *Client) objectURL(key string) string { + segments := strings.Split(key, "/") + for i, seg := range segments { + segments[i] = uriEncode(seg) + } + return c.bucketURL() + "/" + strings.Join(segments, "/") +} + +// hostHeader returns the value for the Host header. +func (c *Client) hostHeader() string { + if c.PathStyle { + return c.Endpoint + } + return c.Bucket + "." 
+ c.Endpoint +} diff --git a/s3/client_test.go b/s3/client_test.go new file mode 100644 index 00000000..f4a10213 --- /dev/null +++ b/s3/client_test.go @@ -0,0 +1,727 @@ +package s3 + +import ( + "bytes" + "context" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "sort" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/require" +) + +// --- Mock S3 server --- +// +// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and +// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. + +type mockS3Server struct { + objects map[string][]byte // full key (bucket/key) -> body + mu sync.RWMutex +} + +func newMockS3Server() (*httptest.Server, *mockS3Server) { + m := &mockS3Server{objects: make(map[string][]byte)} + return httptest.NewTLSServer(m), m +} + +func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Path is /{bucket}[/{key...}] + path := strings.TrimPrefix(r.URL.Path, "/") + + switch { + case r.Method == http.MethodPut: + m.handlePut(w, r, path) + case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2": + m.handleList(w, r, path) + case r.Method == http.MethodGet: + m.handleGet(w, r, path) + case r.Method == http.MethodPost && r.URL.Query().Has("delete"): + m.handleDelete(w, r, path) + default: + http.Error(w, "not implemented", http.StatusNotImplemented) + } +} + +func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + m.mu.Lock() + m.objects[path] = body + m.mu.Unlock() + w.WriteHeader(http.StatusOK) +} + +func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { + m.mu.RLock() + body, ok := m.objects[path] + m.mu.RUnlock() + if !ok { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`NoSuchKeyThe specified key does not 
exist.`)) + return + } + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) + w.WriteHeader(http.StatusOK) + w.Write(body) +} + +type listObjectsResponse struct { + XMLName xml.Name `xml:"ListBucketResult"` + Contents []listObject `xml:"Contents"` + // Pagination support + IsTruncated bool `xml:"IsTruncated"` + NextContinuationToken string `xml:"NextContinuationToken"` +} + +func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) { + // bucketPath is just the bucket name + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + var req struct { + Objects []struct { + Key string `xml:"Key"` + } `xml:"Object"` + } + if err := xml.Unmarshal(body, &req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + m.mu.Lock() + for _, obj := range req.Objects { + delete(m.objects, bucketPath+"/"+obj.Key) + } + m.mu.Unlock() + w.WriteHeader(http.StatusOK) + w.Write([]byte(``)) +} + +func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) { + prefix := r.URL.Query().Get("prefix") + contToken := r.URL.Query().Get("continuation-token") + + m.mu.RLock() + var allKeys []string + for key := range m.objects { + objKey := strings.TrimPrefix(key, bucketPath+"/") + if objKey == key { + continue // different bucket + } + if prefix == "" || strings.HasPrefix(objKey, prefix) { + allKeys = append(allKeys, objKey) + } + } + m.mu.RUnlock() + sort.Strings(allKeys) + + // Simple continuation token: it's the key to start after + startIdx := 0 + if contToken != "" { + for i, k := range allKeys { + if k == contToken { + startIdx = i + 1 + break + } + } + } + + maxKeys := 1000 + if mk := r.URL.Query().Get("max-keys"); mk != "" { + fmt.Sscanf(mk, "%d", &maxKeys) + } + + endIdx := startIdx + maxKeys + truncated := false + nextToken := "" + if endIdx < len(allKeys) { + truncated = true + nextToken = allKeys[endIdx-1] + 
allKeys = allKeys[startIdx:endIdx] + } else { + allKeys = allKeys[startIdx:] + } + + m.mu.RLock() + var contents []listObject + for _, objKey := range allKeys { + body := m.objects[bucketPath+"/"+objKey] + contents = append(contents, listObject{Key: objKey, Size: int64(len(body))}) + } + m.mu.RUnlock() + + resp := listObjectsResponse{ + Contents: contents, + IsTruncated: truncated, + NextContinuationToken: nextToken, + } + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + xml.NewEncoder(w).Encode(resp) +} + +func (m *mockS3Server) objectCount() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.objects) +} + +// --- Helper to create a test client pointing at mock server --- + +func newTestClient(server *httptest.Server, bucket, prefix string) *Client { + // httptest.NewTLSServer URL is like "https://127.0.0.1:PORT" + host := strings.TrimPrefix(server.URL, "https://") + return &Client{ + AccessKey: "AKID", + SecretKey: "SECRET", + Region: "us-east-1", + Endpoint: host, + Bucket: bucket, + Prefix: prefix, + PathStyle: true, + HTTPClient: server.Client(), + } +} + +// --- URL parsing tests --- + +func TestParseURL_Success(t *testing.T) { + cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1") + require.Nil(t, err) + require.Equal(t, "my-bucket", cfg.Bucket) + require.Equal(t, "attachments", cfg.Prefix) + require.Equal(t, "us-east-1", cfg.Region) + require.Equal(t, "AKID", cfg.AccessKey) + require.Equal(t, "SECRET", cfg.SecretKey) + require.Equal(t, "s3.us-east-1.amazonaws.com", cfg.Endpoint) + require.False(t, cfg.PathStyle) +} + +func TestParseURL_NoPrefix(t *testing.T) { + cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1") + require.Nil(t, err) + require.Equal(t, "my-bucket", cfg.Bucket) + require.Equal(t, "", cfg.Prefix) +} + +func TestParseURL_WithEndpoint(t *testing.T) { + cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/prefix?region=us-east-1&endpoint=https://s3.example.com") + 
require.Nil(t, err) + require.Equal(t, "my-bucket", cfg.Bucket) + require.Equal(t, "prefix", cfg.Prefix) + require.Equal(t, "s3.example.com", cfg.Endpoint) + require.True(t, cfg.PathStyle) +} + +func TestParseURL_EndpointHTTP(t *testing.T) { + cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=http://localhost:9000") + require.Nil(t, err) + require.Equal(t, "localhost:9000", cfg.Endpoint) + require.True(t, cfg.PathStyle) +} + +func TestParseURL_EndpointTrailingSlash(t *testing.T) { + cfg, err := ParseURL("s3://AKID:SECRET@my-bucket?region=us-east-1&endpoint=https://s3.example.com/") + require.Nil(t, err) + require.Equal(t, "s3.example.com", cfg.Endpoint) +} + +func TestParseURL_NestedPrefix(t *testing.T) { + cfg, err := ParseURL("s3://AKID:SECRET@my-bucket/a/b/c?region=us-east-1") + require.Nil(t, err) + require.Equal(t, "my-bucket", cfg.Bucket) + require.Equal(t, "a/b/c", cfg.Prefix) +} + +func TestParseURL_MissingRegion(t *testing.T) { + _, err := ParseURL("s3://AKID:SECRET@my-bucket") + require.Error(t, err) + require.Contains(t, err.Error(), "region") +} + +func TestParseURL_MissingCredentials(t *testing.T) { + _, err := ParseURL("s3://my-bucket?region=us-east-1") + require.Error(t, err) + require.Contains(t, err.Error(), "access key") +} + +func TestParseURL_MissingSecretKey(t *testing.T) { + _, err := ParseURL("s3://AKID@my-bucket?region=us-east-1") + require.Error(t, err) + require.Contains(t, err.Error(), "secret key") +} + +func TestParseURL_WrongScheme(t *testing.T) { + _, err := ParseURL("http://AKID:SECRET@my-bucket?region=us-east-1") + require.Error(t, err) + require.Contains(t, err.Error(), "scheme") +} + +func TestParseURL_EmptyBucket(t *testing.T) { + _, err := ParseURL("s3://AKID:SECRET@?region=us-east-1") + require.Error(t, err) + require.Contains(t, err.Error(), "bucket") +} + +// --- Unit tests: URL construction --- + +func TestClient_BucketURL_PathStyle(t *testing.T) { + c := &Client{Endpoint: "s3.example.com", Bucket: 
"my-bucket", PathStyle: true} + require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL()) +} + +func TestClient_BucketURL_VirtualHosted(t *testing.T) { + c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL()) +} + +func TestClient_ObjectURL_PathStyle(t *testing.T) { + c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} + require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj")) +} + +func TestClient_ObjectURL_VirtualHosted(t *testing.T) { + c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj")) +} + +func TestClient_HostHeader_PathStyle(t *testing.T) { + c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} + require.Equal(t, "s3.example.com", c.hostHeader()) +} + +func TestClient_HostHeader_VirtualHosted(t *testing.T) { + c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader()) +} + +func TestClient_ObjectKey(t *testing.T) { + c := &Client{Prefix: "attachments"} + require.Equal(t, "attachments/file123", c.objectKey("file123")) + + c2 := &Client{Prefix: ""} + require.Equal(t, "file123", c2.objectKey("file123")) +} + +func TestClient_PrefixForList(t *testing.T) { + c := &Client{Prefix: "attachments"} + require.Equal(t, "attachments/", c.prefixForList()) + + c2 := &Client{Prefix: ""} + require.Equal(t, "", c2.prefixForList()) +} + +// --- Integration tests using mock S3 server --- + +func TestClient_PutGetObject(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + // Put + err := client.PutObject(ctx, 
"test-key", strings.NewReader("hello world"), 11) + require.Nil(t, err) + + // Get + reader, size, err := client.GetObject(ctx, "test-key") + require.Nil(t, err) + require.Equal(t, int64(11), size) + data, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "hello world", string(data)) +} + +func TestClient_PutGetObject_WithPrefix(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "pfx") + + ctx := context.Background() + + err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 5) + require.Nil(t, err) + + reader, _, err := client.GetObject(ctx, "test-key") + require.Nil(t, err) + data, _ := io.ReadAll(reader) + reader.Close() + require.Equal(t, "hello", string(data)) +} + +func TestClient_GetObject_NotFound(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + _, _, err := client.GetObject(context.Background(), "nonexistent") + require.Error(t, err) + var errResp *ErrorResponse + require.ErrorAs(t, err, &errResp) + require.Equal(t, 404, errResp.StatusCode) + require.Equal(t, "NoSuchKey", errResp.Code) +} + +func TestClient_DeleteObjects(t *testing.T) { + server, mock := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + // Put several objects + for i := 0; i < 5; i++ { + err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 4) + require.Nil(t, err) + } + require.Equal(t, 5, mock.objectCount()) + + // Delete some + err := client.DeleteObjects(ctx, []string{"key-1", "key-3"}) + require.Nil(t, err) + require.Equal(t, 3, mock.objectCount()) + + // Verify deleted ones are gone + _, _, err = client.GetObject(ctx, "key-1") + require.Error(t, err) + _, _, err = client.GetObject(ctx, "key-3") + require.Error(t, err) + + // Verify remaining ones are still there + reader, _, err := 
client.GetObject(ctx, "key-0") + require.Nil(t, err) + reader.Close() +} + +func TestClient_ListObjects(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + + ctx := context.Background() + + // Client with prefix "pfx": list should only return objects under pfx/ + client := newTestClient(server, "my-bucket", "pfx") + for i := 0; i < 3; i++ { + err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 1) + require.Nil(t, err) + } + + // Also put an object outside the prefix using a no-prefix client + clientNoPrefix := newTestClient(server, "my-bucket", "") + err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 1) + require.Nil(t, err) + + // List with prefix client: should only see 3 + result, err := client.ListObjects(ctx, "", 0) + require.Nil(t, err) + require.Len(t, result.Objects, 3) + require.False(t, result.IsTruncated) + + // List with no-prefix client: should see all 4 + result, err = clientNoPrefix.ListObjects(ctx, "", 0) + require.Nil(t, err) + require.Len(t, result.Objects, 4) +} + +func TestClient_ListObjects_Pagination(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + // Put 5 objects + for i := 0; i < 5; i++ { + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1) + require.Nil(t, err) + } + + // List with max-keys=2 + result, err := client.ListObjects(ctx, "", 2) + require.Nil(t, err) + require.Len(t, result.Objects, 2) + require.True(t, result.IsTruncated) + require.NotEmpty(t, result.NextContinuationToken) + + // Get next page + result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2) + require.Nil(t, err) + require.Len(t, result2.Objects, 2) + require.True(t, result2.IsTruncated) + + // Get last page + result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2) + require.Nil(t, err) + require.Len(t, 
result3.Objects, 1) + require.False(t, result3.IsTruncated) +} + +func TestClient_ListAllObjects(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "pfx") + + ctx := context.Background() + + for i := 0; i < 10; i++ { + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1) + require.Nil(t, err) + } + + objects, err := client.ListAllObjects(ctx) + require.Nil(t, err) + require.Len(t, objects, 10) +} + +func TestClient_PutObject_LargeBody(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + // 1 MB object + data := make([]byte, 1024*1024) + for i := range data { + data[i] = byte(i % 256) + } + err := client.PutObject(ctx, "large", bytes.NewReader(data), int64(len(data))) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "large") + require.Nil(t, err) + require.Equal(t, int64(1024*1024), size) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, data, got) +} + +func TestClient_PutObject_NestedKey(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 6) + require.Nil(t, err) + + reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt") + require.Nil(t, err) + data, _ := io.ReadAll(reader) + reader.Close() + require.Equal(t, "nested", string(data)) +} + +// --- Scale test: 20k objects (ntfy-adjacent) --- + +func TestClient_ListAllObjects_20k(t *testing.T) { + if testing.Short() { + t.Skip("skipping 20k object test in short mode") + } + + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "attachments") + + ctx := context.Background() + const numObjects = 
20000 + const batchSize = 500 + + // Insert 20k objects in batches to keep it fast + for batch := 0; batch < numObjects/batchSize; batch++ { + for i := 0; i < batchSize; i++ { + idx := batch*batchSize + i + key := fmt.Sprintf("%08d", idx) + err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 1) + require.Nil(t, err) + } + } + + // List all 20k objects with pagination + objects, err := client.ListAllObjects(ctx) + require.Nil(t, err) + require.Len(t, objects, numObjects) + + // Verify total size + var totalSize int64 + for _, obj := range objects { + totalSize += obj.Size + } + require.Equal(t, int64(numObjects), totalSize) + + // Delete 1000 objects (simulating attachment expiry cleanup) + keys := make([]string, 1000) + for i := range keys { + keys[i] = fmt.Sprintf("%08d", i) + } + err = client.DeleteObjects(ctx, keys) + require.Nil(t, err) + + // List again: should have 19000 + objects, err = client.ListAllObjects(ctx) + require.Nil(t, err) + require.Len(t, objects, numObjects-1000) +} + +// --- Real S3 integration test --- +// +// Set the following environment variables to run this test against a real S3 bucket: +// +// S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET +// +// Optional: +// +// S3_ENDPOINT: host[:port] for S3-compatible providers (e.g. 
"nyc3.digitaloceanspaces.com") +// S3_PATH_STYLE: set to "true" for path-style addressing +// S3_PREFIX: key prefix to use (default: "ntfy-s3-test") +func TestClient_RealBucket(t *testing.T) { + accessKey := os.Getenv("S3_ACCESS_KEY") + secretKey := os.Getenv("S3_SECRET_KEY") + region := os.Getenv("S3_REGION") + bucket := os.Getenv("S3_BUCKET") + + if accessKey == "" || secretKey == "" || region == "" || bucket == "" { + t.Skip("skipping real S3 test: set S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET") + } + + endpoint := os.Getenv("S3_ENDPOINT") + if endpoint == "" { + endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region) + } + pathStyle := os.Getenv("S3_PATH_STYLE") == "true" + prefix := os.Getenv("S3_PREFIX") + if prefix == "" { + prefix = "ntfy-s3-test" + } + + client := &Client{ + AccessKey: accessKey, + SecretKey: secretKey, + Region: region, + Endpoint: endpoint, + Bucket: bucket, + Prefix: prefix, + PathStyle: pathStyle, + } + + ctx := context.Background() + + // Clean up any leftover objects from previous runs + existing, err := client.ListAllObjects(ctx) + require.Nil(t, err) + if len(existing) > 0 { + keys := make([]string, len(existing)) + for i, obj := range existing { + // Strip the prefix since DeleteObjects will re-add it + keys[i] = strings.TrimPrefix(obj.Key, prefix+"/") + } + // Batch delete in groups of 1000 + for i := 0; i < len(keys); i += 1000 { + end := i + 1000 + if end > len(keys) { + end = len(keys) + } + err := client.DeleteObjects(ctx, keys[i:end]) + require.Nil(t, err) + } + } + + t.Run("PutGetDelete", func(t *testing.T) { + key := "test-object" + content := "hello from ntfy s3 test" + + // Put + err := client.PutObject(ctx, key, strings.NewReader(content), int64(len(content))) + require.Nil(t, err) + + // Get + reader, size, err := client.GetObject(ctx, key) + require.Nil(t, err) + require.Equal(t, int64(len(content)), size) + data, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, content, 
string(data)) + + // Delete + err = client.DeleteObjects(ctx, []string{key}) + require.Nil(t, err) + + // Get after delete should fail + _, _, err = client.GetObject(ctx, key) + require.Error(t, err) + var errResp *ErrorResponse + require.ErrorAs(t, err, &errResp) + require.Equal(t, 404, errResp.StatusCode) + }) + + t.Run("ListObjects", func(t *testing.T) { + // Use a sub-prefix client for isolation + listClient := &Client{ + AccessKey: accessKey, + SecretKey: secretKey, + Region: region, + Endpoint: endpoint, + Bucket: bucket, + Prefix: prefix + "/list-test", + PathStyle: pathStyle, + } + + // Put 10 objects + for i := 0; i < 10; i++ { + err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"), 1) + require.Nil(t, err) + } + + // List + objects, err := listClient.ListAllObjects(ctx) + require.Nil(t, err) + require.Len(t, objects, 10) + + // Clean up + keys := make([]string, 10) + for i := range keys { + keys[i] = fmt.Sprintf("%d", i) + } + err = listClient.DeleteObjects(ctx, keys) + require.Nil(t, err) + }) + + t.Run("LargeObject", func(t *testing.T) { + key := "large-object" + data := make([]byte, 5*1024*1024) // 5 MB + for i := range data { + data[i] = byte(i % 256) + } + + err := client.PutObject(ctx, key, bytes.NewReader(data), int64(len(data))) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, key) + require.Nil(t, err) + require.Equal(t, int64(len(data)), size) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, data, got) + + err = client.DeleteObjects(ctx, []string{key}) + require.Nil(t, err) + }) +} diff --git a/s3/types.go b/s3/types.go new file mode 100644 index 00000000..5929ec6c --- /dev/null +++ b/s3/types.go @@ -0,0 +1,65 @@ +package s3 + +import "fmt" + +// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string. +type Config struct { + Endpoint string // host[:port] only, e.g. 
"s3.us-east-1.amazonaws.com" + PathStyle bool + Bucket string + Prefix string + Region string + AccessKey string + SecretKey string +} + +// Object represents an S3 object returned by list operations. +type Object struct { + Key string + Size int64 +} + +// ListResult holds the response from a ListObjectsV2 call. +type ListResult struct { + Objects []Object + IsTruncated bool + NextContinuationToken string +} + +// ErrorResponse is returned when S3 responds with a non-2xx status code. +type ErrorResponse struct { + StatusCode int + Code string `xml:"Code"` + Message string `xml:"Message"` + Body string `xml:"-"` // raw response body +} + +func (e *ErrorResponse) Error() string { + if e.Code != "" { + return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message) + } + return fmt.Sprintf("s3: HTTP %d: %s", e.StatusCode, e.Body) +} + +// listObjectsV2Response is the XML response from S3 ListObjectsV2 +type listObjectsV2Response struct { + Contents []listObject `xml:"Contents"` + IsTruncated bool `xml:"IsTruncated"` + NextContinuationToken string `xml:"NextContinuationToken"` +} + +type listObject struct { + Key string `xml:"Key"` + Size int64 `xml:"Size"` +} + +// deleteResult is the XML response from S3 DeleteObjects +type deleteResult struct { + Errors []deleteError `xml:"Error"` +} + +type deleteError struct { + Key string `xml:"Key"` + Code string `xml:"Code"` + Message string `xml:"Message"` +} diff --git a/s3/util.go b/s3/util.go new file mode 100644 index 00000000..cf9d4ba8 --- /dev/null +++ b/s3/util.go @@ -0,0 +1,161 @@ +package s3 + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strings" +) + +const ( + // SHA-256 hash of the empty string, used as the payload hash for bodiless requests + emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + + // Sent as the payload hash for streaming uploads where the body is not buffered in 
memory + unsignedPayload = "UNSIGNED-PAYLOAD" + + // maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB) + maxResponseBytes = 10 * 1024 * 1024 +) + +// ParseURL parses an S3 URL of the form: +// +// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] +// +// When endpoint is specified, path-style addressing is enabled automatically. +func ParseURL(s3URL string) (*Config, error) { + u, err := url.Parse(s3URL) + if err != nil { + return nil, fmt.Errorf("s3: invalid URL: %w", err) + } + if u.Scheme != "s3" { + return nil, fmt.Errorf("s3: URL scheme must be 's3', got '%s'", u.Scheme) + } + if u.Host == "" { + return nil, fmt.Errorf("s3: bucket name must be specified as host") + } + bucket := u.Host + prefix := strings.TrimPrefix(u.Path, "/") + accessKey := u.User.Username() + secretKey, _ := u.User.Password() + if accessKey == "" || secretKey == "" { + return nil, fmt.Errorf("s3: access key and secret key must be specified in URL") + } + region := u.Query().Get("region") + if region == "" { + return nil, fmt.Errorf("s3: region query parameter is required") + } + endpointParam := u.Query().Get("endpoint") + var endpoint string + var pathStyle bool + if endpointParam != "" { + // Custom endpoint: strip scheme prefix to extract host[:port] + ep := strings.TrimRight(endpointParam, "/") + ep = strings.TrimPrefix(ep, "https://") + ep = strings.TrimPrefix(ep, "http://") + endpoint = ep + pathStyle = true + } else { + endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region) + pathStyle = false + } + return &Config{ + Endpoint: endpoint, + PathStyle: pathStyle, + Bucket: bucket, + Prefix: prefix, + Region: region, + AccessKey: accessKey, + SecretKey: secretKey, + }, nil +} + +// parseError reads an S3 error response and returns an *ErrorResponse. 
+func parseError(resp *http.Response) error { + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return fmt.Errorf("s3: reading error response: %w", err) + } + return parseErrorFromBytes(resp.StatusCode, body) +} + +func parseErrorFromBytes(statusCode int, body []byte) error { + errResp := &ErrorResponse{ + StatusCode: statusCode, + Body: string(body), + } + // Try to parse XML error; if it fails, we still have StatusCode and Body + _ = xml.Unmarshal(body, errResp) + return errResp +} + +// canonicalURI returns the URI-encoded path for the canonical request. Each path segment is +// percent-encoded per RFC 3986; forward slashes are preserved. +func canonicalURI(u *url.URL) string { + p := u.Path + if p == "" { + return "/" + } + segments := strings.Split(p, "/") + for i, seg := range segments { + segments[i] = uriEncode(seg) + } + return strings.Join(segments, "/") +} + +// canonicalQueryString builds the query string for the canonical request. Keys and values +// are URI-encoded per RFC 3986 (using %20, not +) and sorted lexically by key. +func canonicalQueryString(values url.Values) string { + if len(values) == 0 { + return "" + } + keys := make([]string, 0, len(values)) + for k := range values { + keys = append(keys, k) + } + sort.Strings(keys) + var pairs []string + for _, k := range keys { + ek := uriEncode(k) + vs := make([]string, len(values[k])) + copy(vs, values[k]) + sort.Strings(vs) + for _, v := range vs { + pairs = append(pairs, ek+"="+uriEncode(v)) + } + } + return strings.Join(pairs, "&") +} + +// uriEncode percent-encodes a string per RFC 3986, encoding everything except unreserved +// characters (A-Z a-z 0-9 - _ . ~). +func uriEncode(s string) string { + var buf strings.Builder + for i := 0; i < len(s); i++ { + b := s[i] + if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') || + b == '-' || b == '_' || b == '.' 
|| b == '~' { + buf.WriteByte(b) + } else { + fmt.Fprintf(&buf, "%%%02X", b) + } + } + return buf.String() +} + +func sha256Hex(data []byte) string { + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]) +} + +func hmacSHA256(key, data []byte) []byte { + h := hmac.New(sha256.New, key) + h.Write(data) + return h.Sum(nil) +} diff --git a/s3/util_test.go b/s3/util_test.go new file mode 100644 index 00000000..d30c5664 --- /dev/null +++ b/s3/util_test.go @@ -0,0 +1,181 @@ +package s3 + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestURIEncode(t *testing.T) { + // Unreserved characters are not encoded + require.Equal(t, "abcdefghijklmnopqrstuvwxyz", uriEncode("abcdefghijklmnopqrstuvwxyz")) + require.Equal(t, "ABCDEFGHIJKLMNOPQRSTUVWXYZ", uriEncode("ABCDEFGHIJKLMNOPQRSTUVWXYZ")) + require.Equal(t, "0123456789", uriEncode("0123456789")) + require.Equal(t, "-_.~", uriEncode("-_.~")) + + // Spaces use %20, not + + require.Equal(t, "hello%20world", uriEncode("hello world")) + + // Slashes are encoded (canonicalURI handles slash splitting separately) + require.Equal(t, "a%2Fb", uriEncode("a/b")) + + // Special characters + require.Equal(t, "%2B", uriEncode("+")) + require.Equal(t, "%3D", uriEncode("=")) + require.Equal(t, "%26", uriEncode("&")) + require.Equal(t, "%40", uriEncode("@")) + require.Equal(t, "%23", uriEncode("#")) + + // Mixed + require.Equal(t, "test~file-name_1.txt", uriEncode("test~file-name_1.txt")) + require.Equal(t, "key%20with%20spaces%2Fand%2Fslashes", uriEncode("key with spaces/and/slashes")) + + // Empty string + require.Equal(t, "", uriEncode("")) +} + +func TestCanonicalURI(t *testing.T) { + // Simple path + u, _ := url.Parse("https://example.com/bucket/key") + require.Equal(t, "/bucket/key", canonicalURI(u)) + + // Root path + u, _ = url.Parse("https://example.com/") + require.Equal(t, "/", canonicalURI(u)) + + // Empty path + u, _ = url.Parse("https://example.com") + require.Equal(t, "/", 
canonicalURI(u)) + + // Path with special characters + u, _ = url.Parse("https://example.com/bucket/key%20with%20spaces") + require.Equal(t, "/bucket/key%20with%20spaces", canonicalURI(u)) + + // Nested path + u, _ = url.Parse("https://example.com/bucket/a/b/c/file.txt") + require.Equal(t, "/bucket/a/b/c/file.txt", canonicalURI(u)) +} + +func TestCanonicalQueryString(t *testing.T) { + // Multiple keys sorted alphabetically + vals := url.Values{ + "prefix": {"test/"}, + "list-type": {"2"}, + } + require.Equal(t, "list-type=2&prefix=test%2F", canonicalQueryString(vals)) + + // Empty values + require.Equal(t, "", canonicalQueryString(url.Values{})) + + // Single key + require.Equal(t, "key=value", canonicalQueryString(url.Values{"key": {"value"}})) + + // Key with multiple values (sorted) + vals = url.Values{"key": {"b", "a"}} + require.Equal(t, "key=a&key=b", canonicalQueryString(vals)) + + // Keys requiring encoding + vals = url.Values{"continuation-token": {"abc+def"}} + require.Equal(t, "continuation-token=abc%2Bdef", canonicalQueryString(vals)) +} + +func TestSHA256Hex(t *testing.T) { + // SHA-256 of empty string + require.Equal(t, emptyPayloadHash, sha256Hex([]byte(""))) + + // SHA-256 of known value + require.Equal(t, "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", sha256Hex([]byte("hello"))) +} + +func TestHmacSHA256(t *testing.T) { + // Known test vector: HMAC-SHA256("key", "message") + result := hmacSHA256([]byte("key"), []byte("message")) + require.Len(t, result, 32) // SHA-256 produces 32 bytes + require.NotEqual(t, make([]byte, 32), result) + + // Same inputs should produce same output + result2 := hmacSHA256([]byte("key"), []byte("message")) + require.Equal(t, result, result2) + + // Different inputs should produce different output + result3 := hmacSHA256([]byte("different-key"), []byte("message")) + require.NotEqual(t, result, result3) +} + +func TestSignV4_SetsRequiredHeaders(t *testing.T) { + c := &Client{ + AccessKey: "AKID", + 
SecretKey: "SECRET", + Region: "us-east-1", + Endpoint: "s3.us-east-1.amazonaws.com", + Bucket: "my-bucket", + } + + req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil) + c.signV4(req, emptyPayloadHash) + + // All required SigV4 headers must be set + require.NotEmpty(t, req.Header.Get("Host")) + require.NotEmpty(t, req.Header.Get("X-Amz-Date")) + require.Equal(t, emptyPayloadHash, req.Header.Get("X-Amz-Content-Sha256")) + + // Authorization header must have correct format + auth := req.Header.Get("Authorization") + require.Contains(t, auth, "AWS4-HMAC-SHA256") + require.Contains(t, auth, "Credential=AKID/") + require.Contains(t, auth, "/us-east-1/s3/aws4_request") + require.Contains(t, auth, "SignedHeaders=") + require.Contains(t, auth, "Signature=") +} + +func TestSignV4_UnsignedPayload(t *testing.T) { + c := &Client{ + AccessKey: "AKID", + SecretKey: "SECRET", + Region: "us-east-1", + Endpoint: "s3.us-east-1.amazonaws.com", + Bucket: "my-bucket", + } + + req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil) + c.signV4(req, unsignedPayload) + + require.Equal(t, unsignedPayload, req.Header.Get("X-Amz-Content-Sha256")) +} + +func TestSignV4_DifferentRegions(t *testing.T) { + c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"} + c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"} + + req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil) + c1.signV4(req1, emptyPayloadHash) + + req2, _ := http.NewRequest(http.MethodGet, "https://b.s3.eu-west-1.amazonaws.com/key", nil) + c2.signV4(req2, emptyPayloadHash) + + // Different regions should produce different signatures + require.NotEqual(t, req1.Header.Get("Authorization"), req2.Header.Get("Authorization")) +} + +func 
TestParseError_XMLResponse(t *testing.T) { + xmlBody := []byte(`NoSuchKeyThe specified key does not exist.`) + err := parseErrorFromBytes(404, xmlBody) + + var errResp *ErrorResponse + require.ErrorAs(t, err, &errResp) + require.Equal(t, 404, errResp.StatusCode) + require.Equal(t, "NoSuchKey", errResp.Code) + require.Equal(t, "The specified key does not exist.", errResp.Message) +} + +func TestParseError_NonXMLResponse(t *testing.T) { + err := parseErrorFromBytes(500, []byte("internal server error")) + + var errResp *ErrorResponse + require.ErrorAs(t, err, &errResp) + require.Equal(t, 500, errResp.StatusCode) + require.Equal(t, "", errResp.Code) // XML parsing failed, no code + require.Contains(t, errResp.Body, "internal server error") +} diff --git a/tools/s3cli/main.go b/tools/s3cli/main.go new file mode 100644 index 00000000..697d4e71 --- /dev/null +++ b/tools/s3cli/main.go @@ -0,0 +1,164 @@ +// Command s3cli is a minimal CLI for testing the s3 package. It supports put, get, rm, and ls. +// +// Usage: +// +// export S3_URL="s3://ACCESS_KEY:SECRET_KEY@BUCKET/PREFIX?region=REGION&endpoint=ENDPOINT" +// +// s3cli put Upload a file +// s3cli put - Upload from stdin +// s3cli get Download to stdout +// s3cli rm [...] 
Delete one or more objects +// s3cli ls List all objects +package main + +import ( + "context" + "fmt" + "io" + "os" + "text/tabwriter" + + "heckel.io/ntfy/v2/s3" +) + +func main() { + if len(os.Args) < 2 { + usage() + } + s3URL := os.Getenv("S3_URL") + if s3URL == "" { + fail("S3_URL environment variable is required") + } + cfg, err := s3.ParseURL(s3URL) + if err != nil { + fail("invalid S3_URL: %s", err) + } + client := s3.New(cfg) + ctx := context.Background() + + switch os.Args[1] { + case "put": + cmdPut(ctx, client) + case "get": + cmdGet(ctx, client) + case "rm": + cmdRm(ctx, client) + case "ls": + cmdLs(ctx, client) + default: + usage() + } +} + +func cmdPut(ctx context.Context, client *s3.Client) { + if len(os.Args) != 4 { + fail("usage: s3cli put \n") + } + key := os.Args[2] + path := os.Args[3] + + var r io.Reader + var size int64 + if path == "-" { + // Read stdin into a temp file to get the size + tmp, err := os.CreateTemp("", "s3cli-*") + if err != nil { + fail("create temp file: %s", err) + } + defer os.Remove(tmp.Name()) + n, err := io.Copy(tmp, os.Stdin) + if err != nil { + tmp.Close() + fail("read stdin: %s", err) + } + if _, err := tmp.Seek(0, io.SeekStart); err != nil { + tmp.Close() + fail("seek: %s", err) + } + r = tmp + size = n + defer tmp.Close() + } else { + f, err := os.Open(path) + if err != nil { + fail("open %s: %s", path, err) + } + defer f.Close() + info, err := f.Stat() + if err != nil { + fail("stat %s: %s", path, err) + } + r = f + size = info.Size() + } + + if err := client.PutObject(ctx, key, r, size); err != nil { + fail("put: %s", err) + } + fmt.Fprintf(os.Stderr, "uploaded %s (%d bytes)\n", key, size) +} + +func cmdGet(ctx context.Context, client *s3.Client) { + if len(os.Args) != 3 { + fail("usage: s3cli get \n") + } + key := os.Args[2] + + reader, size, err := client.GetObject(ctx, key) + if err != nil { + fail("get: %s", err) + } + defer reader.Close() + n, err := io.Copy(os.Stdout, reader) + if err != nil { + fail("read: 
%s", err) + } + fmt.Fprintf(os.Stderr, "downloaded %s (%d bytes, content-length: %d)\n", key, n, size) +} + +func cmdRm(ctx context.Context, client *s3.Client) { + if len(os.Args) < 3 { + fail("usage: s3cli rm [...]\n") + } + keys := os.Args[2:] + if err := client.DeleteObjects(ctx, keys); err != nil { + fail("rm: %s", err) + } + fmt.Fprintf(os.Stderr, "deleted %d object(s)\n", len(keys)) +} + +func cmdLs(ctx context.Context, client *s3.Client) { + objects, err := client.ListAllObjects(ctx) + if err != nil { + fail("ls: %s", err) + } + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + var totalSize int64 + for _, obj := range objects { + fmt.Fprintf(w, "%d\t%s\n", obj.Size, obj.Key) + totalSize += obj.Size + } + w.Flush() + fmt.Fprintf(os.Stderr, "%d object(s), %d bytes total\n", len(objects), totalSize) +} + +func usage() { + fmt.Fprintf(os.Stderr, `Usage: s3cli [args...] + +Commands: + put Upload a file (use - for stdin) + get Download to stdout + rm [keys...] Delete objects + ls List all objects + +Environment: + S3_URL S3 connection URL (required) + s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] +`) + os.Exit(1) +} + +func fail(format string, args ...any) { + fmt.Fprintf(os.Stderr, format+"\n", args...) 
+ os.Exit(1) +} From 86015e100c00b908dbd810f6f93e2a74befbc9ee Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Mon, 16 Mar 2026 20:00:19 -0400 Subject: [PATCH 04/32] Multipart upload --- .gitignore | 1 + attachment/store_s3.go | 49 +++++---- attachment/store_s3_test.go | 93 ++++++++++++++++- s3/client.go | 197 ++++++++++++++++++++++++++++++++---- s3/client_test.go | 167 ++++++++++++++++++++++++++---- s3/types.go | 11 ++ s3/util.go | 5 + tools/s3cli/main.go | 29 +----- 8 files changed, 462 insertions(+), 90 deletions(-) diff --git a/.gitignore b/.gitignore index ed17b2d4..6d5deb67 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ server/site/ tools/fbsend/fbsend tools/pgimport/pgimport tools/loadtest/loadtest +tools/s3cli/s3cli playground/ secrets/ *.iml diff --git a/attachment/store_s3.go b/attachment/store_s3.go index 5c47a81b..38f0353a 100644 --- a/attachment/store_s3.go +++ b/attachment/store_s3.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "os" "sync" "heckel.io/ntfy/v2/log" @@ -51,33 +50,31 @@ func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int6 } log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3") - // Write through limiters into a temp file. This avoids buffering the full attachment in - // memory while still giving us the Content-Length that PutObject requires. + // Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked + // uploads, so no temp file or Content-Length is needed. limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - tmpFile, err := os.CreateTemp("", "ntfy-s3-upload-*") - if err != nil { - return 0, fmt.Errorf("s3 store: failed to create temp file: %w", err) + pr, pw := io.Pipe() + lw := util.NewLimitWriter(pw, limiters...) 
+ var size int64 + var copyErr error + done := make(chan struct{}) + go func() { + defer close(done) + size, copyErr = io.Copy(lw, in) + if copyErr != nil { + pw.CloseWithError(copyErr) + } else { + pw.Close() + } + }() + putErr := c.client.PutObject(context.Background(), id, pr) + pr.Close() + <-done + if copyErr != nil { + return 0, copyErr } - tmpPath := tmpFile.Name() - defer os.Remove(tmpPath) - limitWriter := util.NewLimitWriter(tmpFile, limiters...) - size, err := io.Copy(limitWriter, in) - if err != nil { - tmpFile.Close() - return 0, err - } - if err := tmpFile.Close(); err != nil { - return 0, err - } - - // Re-open the temp file for reading and stream it to S3 - f, err := os.Open(tmpPath) - if err != nil { - return 0, err - } - defer f.Close() - if err := c.client.PutObject(context.Background(), id, f, size); err != nil { - return 0, err + if putErr != nil { + return 0, putErr } c.mu.Lock() c.totalSizeCurrent += size diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index c898244d..872e8c23 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -167,27 +167,41 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. 
type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body + objects map[string][]byte // full key (bucket/key) -> body + uploads map[string]map[int][]byte // uploadID -> partNumber -> data + nextID int // counter for generating upload IDs mu sync.RWMutex } func newMockS3Server() *httptest.Server { - m := &mockS3Server{objects: make(map[string][]byte)} + m := &mockS3Server{ + objects: make(map[string][]byte), + uploads: make(map[string]map[int][]byte), + } return httptest.NewTLSServer(m) } func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Path is /{bucket}[/{key...}] path := strings.TrimPrefix(r.URL.Path, "/") + q := r.URL.Query() switch { + case r.Method == http.MethodPut && q.Has("partNumber"): + m.handleUploadPart(w, r, path) case r.Method == http.MethodPut: m.handlePut(w, r, path) - case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2": + case r.Method == http.MethodPost && q.Has("uploads"): + m.handleInitiateMultipart(w, r, path) + case r.Method == http.MethodPost && q.Has("uploadId"): + m.handleCompleteMultipart(w, r, path) + case r.Method == http.MethodDelete && q.Has("uploadId"): + m.handleAbortMultipart(w, r, path) + case r.Method == http.MethodGet && q.Get("list-type") == "2": m.handleList(w, r, path) case r.Method == http.MethodGet: m.handleGet(w, r, path) - case r.Method == http.MethodPost && r.URL.Query().Has("delete"): + case r.Method == http.MethodPost && q.Has("delete"): m.handleDelete(w, r, path) default: http.Error(w, "not implemented", http.StatusNotImplemented) @@ -206,6 +220,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st w.WriteHeader(http.StatusOK) } +func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) { + m.mu.Lock() + m.nextID++ + uploadID := fmt.Sprintf("upload-%d", m.nextID) + m.uploads[uploadID] = make(map[int][]byte) + m.mu.Unlock() + + w.Header().Set("Content-Type", 
"application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `%s`, uploadID) +} + +func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + var partNumber int + fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber) + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + m.mu.Lock() + parts, ok := m.uploads[uploadID] + if !ok { + m.mu.Unlock() + http.Error(w, "NoSuchUpload", http.StatusNotFound) + return + } + parts[partNumber] = body + m.mu.Unlock() + + etag := fmt.Sprintf(`"etag-part-%d"`, partNumber) + w.Header().Set("ETag", etag) + w.WriteHeader(http.StatusOK) +} + +func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + + m.mu.Lock() + parts, ok := m.uploads[uploadID] + if !ok { + m.mu.Unlock() + http.Error(w, "NoSuchUpload", http.StatusNotFound) + return + } + + // Assemble parts in order + var assembled []byte + for i := 1; i <= len(parts); i++ { + assembled = append(assembled, parts[i]...) 
+ } + m.objects[path] = assembled + delete(m.uploads, uploadID) + m.mu.Unlock() + + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `%s`, path) +} + +func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + m.mu.Lock() + delete(m.uploads, uploadID) + m.mu.Unlock() + w.WriteHeader(http.StatusNoContent) +} + func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { m.mu.RLock() body, ok := m.objects[path] diff --git a/s3/client.go b/s3/client.go index 7fdd8093..29e10d2d 100644 --- a/s3/client.go +++ b/s3/client.go @@ -10,6 +10,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/xml" + "errors" "fmt" "io" "net/http" @@ -50,25 +51,22 @@ func New(config *Config) *Client { } // PutObject uploads body to the given key. The key is automatically prefixed with the client's -// configured prefix. The body size must be known in advance. The payload is sent as -// UNSIGNED-PAYLOAD, which is supported by all major S3-compatible providers over HTTPS. -func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, size int64) error { - fullKey := c.objectKey(key) - req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body) +// configured prefix. The body size does not need to be known in advance. +// +// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request. +// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time +// into memory. 
+func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error { + first := make([]byte, partSize) + n, err := io.ReadFull(body, first) + if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF { + return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n)) + } if err != nil { - return fmt.Errorf("s3: PutObject request: %w", err) + return fmt.Errorf("s3: PutObject read: %w", err) } - req.ContentLength = size - c.signV4(req, unsignedPayload) - resp, err := c.httpClient().Do(req) - if err != nil { - return fmt.Errorf("s3: PutObject: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode/100 != 2 { - return parseError(resp) - } - return nil + combined := io.MultiReader(bytes.NewReader(first), body) + return c.putObjectMultipart(ctx, key, combined) } // GetObject downloads an object. The key is automatically prefixed with the client's configured @@ -216,6 +214,171 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) } +// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. +func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error { + fullKey := c.objectKey(key) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body) + if err != nil { + return fmt.Errorf("s3: PutObject request: %w", err) + } + req.ContentLength = size + c.signV4(req, unsignedPayload) + resp, err := c.httpClient().Do(req) + if err != nil { + return fmt.Errorf("s3: PutObject: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return parseError(resp) + } + return nil +} + +// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize +// chunks, uploading each as a separate part. This allows uploading without knowing the total +// body size in advance. 
+func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error { + fullKey := c.objectKey(key) + + // Step 1: Initiate multipart upload + uploadID, err := c.initiateMultipartUpload(ctx, fullKey) + if err != nil { + return err + } + + // Step 2: Upload parts + var parts []completedPart + buf := make([]byte, partSize) + partNumber := 1 + for { + n, err := io.ReadFull(body, buf) + if n > 0 { + etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n]) + if uploadErr != nil { + c.abortMultipartUpload(ctx, fullKey, uploadID) + return uploadErr + } + parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag}) + partNumber++ + } + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + if err != nil { + c.abortMultipartUpload(ctx, fullKey, uploadID) + return fmt.Errorf("s3: PutObject read: %w", err) + } + } + + // Step 3: Complete multipart upload + return c.completeMultipartUpload(ctx, fullKey, uploadID, parts) +} + +// initiateMultipartUpload starts a new multipart upload and returns the upload ID. 
+func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) { + reqURL := c.objectURL(fullKey) + "?uploads" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil) + if err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err) + } + req.ContentLength = 0 + c.signV4(req, emptyPayloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return "", parseError(resp) + } + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err) + } + var result initiateMultipartUploadResult + if err := xml.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) + } + return result.UploadID, nil +} + +// uploadPart uploads a single part of a multipart upload and returns the ETag. +func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) { + reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) + if err != nil { + return "", fmt.Errorf("s3: UploadPart request: %w", err) + } + req.ContentLength = int64(len(data)) + c.signV4(req, unsignedPayload) + resp, err := c.httpClient().Do(req) + if err != nil { + return "", fmt.Errorf("s3: UploadPart: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return "", parseError(resp) + } + etag := resp.Header.Get("ETag") + return etag, nil +} + +// completeMultipartUpload finalizes a multipart upload with the given parts. 
+func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error { + var body bytes.Buffer + body.WriteString("") + for _, p := range parts { + fmt.Fprintf(&body, "%d%s", p.PartNumber, p.ETag) + } + body.WriteString("") + bodyBytes := body.Bytes() + payloadHash := sha256Hex(bodyBytes) + + reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes)) + if err != nil { + return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err) + } + req.ContentLength = int64(len(bodyBytes)) + req.Header.Set("Content-Type", "application/xml") + c.signV4(req, payloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return fmt.Errorf("s3: CompleteMultipartUpload: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode/100 != 2 { + return parseError(resp) + } + // Read response body to check for errors (S3 can return 200 with an error body) + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err) + } + // Check if the response contains an error + var errResp ErrorResponse + if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" { + errResp.StatusCode = resp.StatusCode + return &errResp + } + return nil +} + +// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. +func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) { + reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) + if err != nil { + return + } + c.signV4(req, emptyPayloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return + } + resp.Body.Close() +} + // signV4 signs req in place using AWS Signature V4. 
payloadHash is the hex-encoded SHA-256 // of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. func (c *Client) signV4(req *http.Request, payloadHash string) { diff --git a/s3/client_test.go b/s3/client_test.go index f4a10213..c3a8fe2c 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -23,27 +23,41 @@ import ( // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body + objects map[string][]byte // full key (bucket/key) -> body + uploads map[string]map[int][]byte // uploadID -> partNumber -> data + nextID int // counter for generating upload IDs mu sync.RWMutex } func newMockS3Server() (*httptest.Server, *mockS3Server) { - m := &mockS3Server{objects: make(map[string][]byte)} + m := &mockS3Server{ + objects: make(map[string][]byte), + uploads: make(map[string]map[int][]byte), + } return httptest.NewTLSServer(m), m } func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Path is /{bucket}[/{key...}] path := strings.TrimPrefix(r.URL.Path, "/") + q := r.URL.Query() switch { + case r.Method == http.MethodPut && q.Has("partNumber"): + m.handleUploadPart(w, r, path) case r.Method == http.MethodPut: m.handlePut(w, r, path) - case r.Method == http.MethodGet && r.URL.Query().Get("list-type") == "2": + case r.Method == http.MethodPost && q.Has("uploads"): + m.handleInitiateMultipart(w, r, path) + case r.Method == http.MethodPost && q.Has("uploadId"): + m.handleCompleteMultipart(w, r, path) + case r.Method == http.MethodDelete && q.Has("uploadId"): + m.handleAbortMultipart(w, r, path) + case r.Method == http.MethodGet && q.Get("list-type") == "2": m.handleList(w, r, path) case r.Method == http.MethodGet: m.handleGet(w, r, path) - case r.Method == http.MethodPost && r.URL.Query().Has("delete"): + case r.Method == http.MethodPost && q.Has("delete"): m.handleDelete(w, r, path) default: 
http.Error(w, "not implemented", http.StatusNotImplemented) @@ -62,6 +76,77 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st w.WriteHeader(http.StatusOK) } +func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) { + m.mu.Lock() + m.nextID++ + uploadID := fmt.Sprintf("upload-%d", m.nextID) + m.uploads[uploadID] = make(map[int][]byte) + m.mu.Unlock() + + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `%s`, uploadID) +} + +func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + var partNumber int + fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber) + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + m.mu.Lock() + parts, ok := m.uploads[uploadID] + if !ok { + m.mu.Unlock() + http.Error(w, "NoSuchUpload", http.StatusNotFound) + return + } + parts[partNumber] = body + m.mu.Unlock() + + etag := fmt.Sprintf(`"etag-part-%d"`, partNumber) + w.Header().Set("ETag", etag) + w.WriteHeader(http.StatusOK) +} + +func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + + m.mu.Lock() + parts, ok := m.uploads[uploadID] + if !ok { + m.mu.Unlock() + http.Error(w, "NoSuchUpload", http.StatusNotFound) + return + } + + // Assemble parts in order + var assembled []byte + for i := 1; i <= len(parts); i++ { + assembled = append(assembled, parts[i]...) 
+ } + m.objects[path] = assembled + delete(m.uploads, uploadID) + m.mu.Unlock() + + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `%s`, path) +} + +func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) { + uploadID := r.URL.Query().Get("uploadId") + m.mu.Lock() + delete(m.uploads, uploadID) + m.mu.Unlock() + w.WriteHeader(http.StatusNoContent) +} + func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { m.mu.RLock() body, ok := m.objects[path] @@ -333,7 +418,7 @@ func TestClient_PutGetObject(t *testing.T) { ctx := context.Background() // Put - err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"), 11) + err := client.PutObject(ctx, "test-key", strings.NewReader("hello world")) require.Nil(t, err) // Get @@ -353,7 +438,7 @@ func TestClient_PutGetObject_WithPrefix(t *testing.T) { ctx := context.Background() - err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 5) + err := client.PutObject(ctx, "test-key", strings.NewReader("hello")) require.Nil(t, err) reader, _, err := client.GetObject(ctx, "test-key") @@ -385,7 +470,7 @@ func TestClient_DeleteObjects(t *testing.T) { // Put several objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 4) + err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data"))) require.Nil(t, err) } require.Equal(t, 5, mock.objectCount()) @@ -416,13 +501,13 @@ func TestClient_ListObjects(t *testing.T) { // Client with prefix "pfx": list should only return objects under pfx/ client := newTestClient(server, "my-bucket", "pfx") for i := 0; i < 3; i++ { - err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 1) + err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x"))) require.Nil(t, err) } // Also put an object outside the prefix using a 
no-prefix client clientNoPrefix := newTestClient(server, "my-bucket", "") - err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 1) + err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y"))) require.Nil(t, err) // List with prefix client: should only see 3 @@ -446,7 +531,7 @@ func TestClient_ListObjects_Pagination(t *testing.T) { // Put 5 objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1) + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x"))) require.Nil(t, err) } @@ -478,7 +563,7 @@ func TestClient_ListAllObjects(t *testing.T) { ctx := context.Background() for i := 0; i < 10; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 1) + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x"))) require.Nil(t, err) } @@ -499,7 +584,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) { for i := range data { data[i] = byte(i % 256) } - err := client.PutObject(ctx, "large", bytes.NewReader(data), int64(len(data))) + err := client.PutObject(ctx, "large", bytes.NewReader(data)) require.Nil(t, err) reader, size, err := client.GetObject(ctx, "large") @@ -511,6 +596,54 @@ func TestClient_PutObject_LargeBody(t *testing.T) { require.Equal(t, data, got) } +func TestClient_PutObject_ChunkedUpload(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + // 12 MB object, exceeds 5 MB partSize, triggers multipart upload path + data := make([]byte, 12*1024*1024) + for i := range data { + data[i] = byte(i % 256) + } + err := client.PutObject(ctx, "multipart", bytes.NewReader(data)) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "multipart") + require.Nil(t, err) + require.Equal(t, int64(12*1024*1024), size) + got, err := io.ReadAll(reader) + 
reader.Close() + require.Nil(t, err) + require.Equal(t, data, got) +} + +func TestClient_PutObject_ExactPartSize(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "") + + ctx := context.Background() + + // Exactly 5 MB (partSize), should use the simple put path (ReadFull succeeds fully) + data := make([]byte, 5*1024*1024) + for i := range data { + data[i] = byte(i % 256) + } + err := client.PutObject(ctx, "exact", bytes.NewReader(data)) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "exact") + require.Nil(t, err) + require.Equal(t, int64(5*1024*1024), size) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, data, got) +} + func TestClient_PutObject_NestedKey(t *testing.T) { server, _ := newMockS3Server() defer server.Close() @@ -518,7 +651,7 @@ func TestClient_PutObject_NestedKey(t *testing.T) { ctx := context.Background() - err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 6) + err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested")) require.Nil(t, err) reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt") @@ -548,7 +681,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) { for i := 0; i < batchSize; i++ { idx := batch*batchSize + i key := fmt.Sprintf("%08d", idx) - err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 1) + err := client.PutObject(ctx, key, bytes.NewReader([]byte("x"))) require.Nil(t, err) } } @@ -647,7 +780,7 @@ func TestClient_RealBucket(t *testing.T) { content := "hello from ntfy s3 test" // Put - err := client.PutObject(ctx, key, strings.NewReader(content), int64(len(content))) + err := client.PutObject(ctx, key, strings.NewReader(content)) require.Nil(t, err) // Get @@ -685,7 +818,7 @@ func TestClient_RealBucket(t *testing.T) { // Put 10 objects for i := 0; i < 10; i++ { - err := listClient.PutObject(ctx, 
fmt.Sprintf("%d", i), strings.NewReader("x"), 1) + err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x")) require.Nil(t, err) } @@ -710,7 +843,7 @@ func TestClient_RealBucket(t *testing.T) { data[i] = byte(i % 256) } - err := client.PutObject(ctx, key, bytes.NewReader(data), int64(len(data))) + err := client.PutObject(ctx, key, bytes.NewReader(data)) require.Nil(t, err) reader, size, err := client.GetObject(ctx, key) diff --git a/s3/types.go b/s3/types.go index 5929ec6c..201c570b 100644 --- a/s3/types.go +++ b/s3/types.go @@ -63,3 +63,14 @@ type deleteError struct { Code string `xml:"Code"` Message string `xml:"Message"` } + +// initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload +type initiateMultipartUploadResult struct { + UploadID string `xml:"UploadId"` +} + +// completedPart represents a successfully uploaded part for CompleteMultipartUpload +type completedPart struct { + PartNumber int + ETag string +} diff --git a/s3/util.go b/s3/util.go index cf9d4ba8..c24c1c5b 100644 --- a/s3/util.go +++ b/s3/util.go @@ -22,6 +22,11 @@ const ( // maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB) maxResponseBytes = 10 * 1024 * 1024 + + // partSize is the size of each part for multipart uploads (5 MB). This is also the threshold + // above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum + // part size of 5 MB for all parts except the last. 
+ partSize = 5 * 1024 * 1024 ) // ParseURL parses an S3 URL of the form: diff --git a/tools/s3cli/main.go b/tools/s3cli/main.go index 697d4e71..1dbac0cf 100644 --- a/tools/s3cli/main.go +++ b/tools/s3cli/main.go @@ -58,44 +58,21 @@ func cmdPut(ctx context.Context, client *s3.Client) { path := os.Args[3] var r io.Reader - var size int64 if path == "-" { - // Read stdin into a temp file to get the size - tmp, err := os.CreateTemp("", "s3cli-*") - if err != nil { - fail("create temp file: %s", err) - } - defer os.Remove(tmp.Name()) - n, err := io.Copy(tmp, os.Stdin) - if err != nil { - tmp.Close() - fail("read stdin: %s", err) - } - if _, err := tmp.Seek(0, io.SeekStart); err != nil { - tmp.Close() - fail("seek: %s", err) - } - r = tmp - size = n - defer tmp.Close() + r = os.Stdin } else { f, err := os.Open(path) if err != nil { fail("open %s: %s", path, err) } defer f.Close() - info, err := f.Stat() - if err != nil { - fail("stat %s: %s", path, err) - } r = f - size = info.Size() } - if err := client.PutObject(ctx, key, r, size); err != nil { + if err := client.PutObject(ctx, key, r); err != nil { fail("put: %s", err) } - fmt.Fprintf(os.Stderr, "uploaded %s (%d bytes)\n", key, size) + fmt.Fprintf(os.Stderr, "uploaded %s\n", key) } func cmdGet(ctx context.Context, client *s3.Client) { From a47d692cbf0aa8d85b763f8d2d0edf87e9edc253 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Tue, 17 Mar 2026 07:50:28 -0400 Subject: [PATCH 05/32] Fix bug --- s3/client.go | 14 +++++++------- s3/util.go | 4 ++++ server/server.go | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/s3/client.go b/s3/client.go index 29e10d2d..ffc4fa8a 100644 --- a/s3/client.go +++ b/s3/client.go @@ -82,7 +82,7 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int6 if err != nil { return nil, 0, fmt.Errorf("s3: GetObject: %w", err) } - if resp.StatusCode/100 != 2 { + if !isHTTPSuccess(resp) { err := parseError(resp) resp.Body.Close() return nil, 0, err @@ 
-126,7 +126,7 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { return fmt.Errorf("s3: DeleteObjects: %w", err) } defer resp.Body.Close() - if resp.StatusCode/100 != 2 { + if !isHTTPSuccess(resp) { return parseError(resp) } @@ -176,7 +176,7 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK if err != nil { return nil, fmt.Errorf("s3: ListObjects read: %w", err) } - if resp.StatusCode/100 != 2 { + if !isHTTPSuccess(resp) { return nil, parseErrorFromBytes(resp.StatusCode, respBody) } var result listObjectsV2Response @@ -228,7 +228,7 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size return fmt.Errorf("s3: PutObject: %w", err) } defer resp.Body.Close() - if resp.StatusCode/100 != 2 { + if !isHTTPSuccess(resp) { return parseError(resp) } return nil @@ -288,7 +288,7 @@ func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (s return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err) } defer resp.Body.Close() - if resp.StatusCode/100 != 2 { + if !isHTTPSuccess(resp) { return "", parseError(resp) } respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) @@ -316,7 +316,7 @@ func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partN return "", fmt.Errorf("s3: UploadPart: %w", err) } defer resp.Body.Close() - if resp.StatusCode/100 != 2 { + if !isHTTPSuccess(resp) { return "", parseError(resp) } etag := resp.Header.Get("ETag") @@ -347,7 +347,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID return fmt.Errorf("s3: CompleteMultipartUpload: %w", err) } defer resp.Body.Close() - if resp.StatusCode/100 != 2 { + if !isHTTPSuccess(resp) { return parseError(resp) } // Read response body to check for errors (S3 can return 200 with an error body) diff --git a/s3/util.go b/s3/util.go index c24c1c5b..2e3fc233 100644 --- a/s3/util.go +++ b/s3/util.go @@ -154,6 +154,10 @@ func uriEncode(s string) 
string { return buf.String() } +func isHTTPSuccess(resp *http.Response) bool { + return resp.StatusCode/100 == 2 +} + func sha256Hex(data []byte) string { h := sha256.Sum256(data) return hex.EncodeToString(h[:]) diff --git a/server/server.go b/server/server.go index f77334de..0972f00d 100644 --- a/server/server.go +++ b/server/server.go @@ -603,7 +603,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.ensureWebEnabled(s.handleStatic)(w, r, v) } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { return s.ensureWebEnabled(s.handleDocs)(w, r, v) - } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" { + } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.fileCache != nil { return s.limitRequests(s.handleFile)(w, r, v) } else if r.Method == http.MethodOptions { return s.limitRequests(s.handleOptions)(w, r, v) // Should work even if the web app is not enabled, see #598 From cffa57950a50af7572e570e1a743582e6850a8b6 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Tue, 17 Mar 2026 16:25:45 -0400 Subject: [PATCH 06/32] Logs --- attachment/store.go | 10 +++++----- s3/client.go | 16 +++++++++++++++- s3/util.go | 3 +++ 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/attachment/store.go b/attachment/store.go index c48a1e90..302eb585 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -10,6 +10,11 @@ import ( "heckel.io/ntfy/v2/util" ) +var ( + fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength)) + errInvalidFileID = errors.New("invalid file ID") +) + // Store is an interface for storing and retrieving attachment files type Store interface { Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) @@ -18,8 +23,3 @@ type Store interface { Size() int64 Remaining() int64 } - -var ( - 
fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength)) - errInvalidFileID = errors.New("invalid file ID") -) diff --git a/s3/client.go b/s3/client.go index ffc4fa8a..56f9608e 100644 --- a/s3/client.go +++ b/s3/client.go @@ -19,6 +19,12 @@ import ( "strconv" "strings" "time" + + "heckel.io/ntfy/v2/log" +) + +const ( + tagS3Client = "s3_client" ) // Client is a minimal S3-compatible client. It supports PutObject, GetObject, DeleteObjects, @@ -60,11 +66,13 @@ func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) erro first := make([]byte, partSize) n, err := io.ReadFull(body, first) if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF { + log.Tag(tagS3Client).Debug("PutObject key=%s size=%d (simple)", key, n) return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n)) } if err != nil { return fmt.Errorf("s3: PutObject read: %w", err) } + log.Tag(tagS3Client).Debug("PutObject key=%s (multipart)", key) combined := io.MultiReader(bytes.NewReader(first), body) return c.putObjectMultipart(ctx, key, combined) } @@ -72,6 +80,7 @@ func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) erro // GetObject downloads an object. The key is automatically prefixed with the client's configured // prefix. The caller must close the returned ReadCloser. func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) { + log.Tag(tagS3Client).Debug("GetObject key=%s", key) fullKey := c.objectKey(key) req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(fullKey), nil) if err != nil { @@ -97,6 +106,7 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int6 // Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present // in the response, they are returned as a combined error. 
func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { + log.Tag(tagS3Client).Debug("DeleteObjects keys=%d", len(keys)) var body bytes.Buffer body.WriteString("true") for _, key := range keys { @@ -152,6 +162,7 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { // ListObjects performs a single ListObjectsV2 request using the client's configured prefix. // Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000). func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) { + log.Tag(tagS3Client).Debug("ListObjects continuation=%s maxKeys=%d", continuationToken, maxKeys) query := url.Values{"list-type": {"2"}} if prefix := c.prefixForList(); prefix != "" { query.Set("prefix", prefix) @@ -197,7 +208,6 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK // ListAllObjects returns all objects under the client's configured prefix by paginating through // ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve. func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { - const maxPages = 10000 var all []Object var token string for page := 0; page < maxPages; page++ { @@ -299,11 +309,13 @@ func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (s if err := xml.Unmarshal(respBody, &result); err != nil { return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) } + log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID) return result.UploadID, nil } // uploadPart uploads a single part of a multipart upload and returns the ETag. 
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) { + log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data)) reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID)) req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) if err != nil { @@ -325,6 +337,7 @@ func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partN // completeMultipartUpload finalizes a multipart upload with the given parts. func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error { + log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts)) var body bytes.Buffer body.WriteString("") for _, p := range parts { @@ -366,6 +379,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID // abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) { + log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID) reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) if err != nil { diff --git a/s3/util.go b/s3/util.go index 2e3fc233..546a940a 100644 --- a/s3/util.go +++ b/s3/util.go @@ -27,6 +27,9 @@ const ( // above which PutObject switches from a simple PUT to multipart upload. S3 requires a minimum // part size of 5 MB for all parts except the last. 
partSize = 5 * 1024 * 1024 + + // maxPages is the max number of pages to iterate through when listing objects + maxPages = 10000 ) // ParseURL parses an S3 URL of the form: From ef314960d015d12fcab803848f4f8a6865d0ae50 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Tue, 17 Mar 2026 20:53:41 -0400 Subject: [PATCH 07/32] Refactor --- attachment/backend.go | 22 ++++ attachment/backend_file.go | 85 ++++++++++++++++ attachment/backend_s3.go | 70 +++++++++++++ attachment/store.go | 186 ++++++++++++++++++++++++++++++++-- attachment/store_file.go | 139 ------------------------- attachment/store_file_test.go | 105 ++++++++++++++----- attachment/store_s3.go | 150 --------------------------- attachment/store_s3_test.go | 180 +++++++++++++++++++++----------- message/cache.go | 76 ++++++-------- message/cache_postgres.go | 3 + message/cache_sqlite.go | 3 + s3/client.go | 10 +- s3/client_test.go | 3 +- s3/types.go | 15 ++- server/server.go | 16 ++- util/limit.go | 55 ++++++++++ 16 files changed, 682 insertions(+), 436 deletions(-) create mode 100644 attachment/backend.go create mode 100644 attachment/backend_file.go create mode 100644 attachment/backend_s3.go delete mode 100644 attachment/store_file.go delete mode 100644 attachment/store_s3.go diff --git a/attachment/backend.go b/attachment/backend.go new file mode 100644 index 00000000..8989b890 --- /dev/null +++ b/attachment/backend.go @@ -0,0 +1,22 @@ +package attachment + +import ( + "io" + "time" +) + +// backendObject represents an object stored in a backend. +type object struct { + ID string + Size int64 + LastModified time.Time +} + +// backend is a minimal I/O interface for storing and retrieving attachment files. +// It has no knowledge of size tracking, limiting, or ID validation. 
+type backend interface { + Put(id string, in io.Reader) error + Get(id string) (io.ReadCloser, int64, error) + Delete(ids ...string) error + List() ([]object, error) +} diff --git a/attachment/backend_file.go b/attachment/backend_file.go new file mode 100644 index 00000000..b0afb2ca --- /dev/null +++ b/attachment/backend_file.go @@ -0,0 +1,85 @@ +package attachment + +import ( + "io" + "os" + "path/filepath" + + "heckel.io/ntfy/v2/log" +) + +const tagFileBackend = "file_backend" + +type fileBackend struct { + dir string +} + +var _ backend = (*fileBackend)(nil) + +func newFileBackend(dir string) (*fileBackend, error) { + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, err + } + return &fileBackend{dir: dir}, nil +} + +func (b *fileBackend) Put(id string, in io.Reader) error { + file := filepath.Join(b.dir, id) + f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer f.Close() + if _, err := io.Copy(f, in); err != nil { + os.Remove(file) + return err + } + if err := f.Close(); err != nil { + os.Remove(file) + return err + } + return nil +} + +func (b *fileBackend) Get(id string) (io.ReadCloser, int64, error) { + file := filepath.Join(b.dir, id) + stat, err := os.Stat(file) + if err != nil { + return nil, 0, err + } + f, err := os.Open(file) + if err != nil { + return nil, 0, err + } + return f, stat.Size(), nil +} + +func (b *fileBackend) Delete(ids ...string) error { + for _, id := range ids { + file := filepath.Join(b.dir, id) + if err := os.Remove(file); err != nil { + log.Tag(tagFileBackend).Field("message_id", id).Err(err).Debug("Error deleting attachment") + } + } + return nil +} + +func (b *fileBackend) List() ([]object, error) { + entries, err := os.ReadDir(b.dir) + if err != nil { + return nil, err + } + objects := make([]object, 0, len(entries)) + for _, e := range entries { + info, err := e.Info() + if err != nil { + return nil, err + } + objects = append(objects, object{ + ID: 
e.Name(), + Size: info.Size(), + LastModified: info.ModTime(), + }) + } + return objects, nil +} diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go new file mode 100644 index 00000000..8fcd8ccb --- /dev/null +++ b/attachment/backend_s3.go @@ -0,0 +1,70 @@ +package attachment + +import ( + "context" + "io" + "strings" + + "heckel.io/ntfy/v2/s3" + + "heckel.io/ntfy/v2/log" +) + +const tagS3Backend = "s3_backend" + +type s3Backend struct { + client *s3.Client +} + +var _ backend = (*s3Backend)(nil) + +func newS3Backend(client *s3.Client) *s3Backend { + return &s3Backend{client: client} +} + +func (b *s3Backend) Put(id string, in io.Reader) error { + return b.client.PutObject(context.Background(), id, in) +} + +func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) { + return b.client.GetObject(context.Background(), id) +} + +func (b *s3Backend) Delete(ids ...string) error { + // S3 DeleteObjects supports up to 1000 keys per call + for i := 0; i < len(ids); i += 1000 { + end := i + 1000 + if end > len(ids) { + end = len(ids) + } + batch := ids[i:end] + for _, id := range batch { + log.Tag(tagS3Backend).Field("message_id", id).Debug("Deleting attachment from S3") + } + if err := b.client.DeleteObjects(context.Background(), batch); err != nil { + return err + } + } + return nil +} + +func (b *s3Backend) List() ([]object, error) { + objects, err := b.client.ListAllObjects(context.Background()) + if err != nil { + return nil, err + } + prefix := b.client.Prefix + result := make([]object, 0, len(objects)) + for _, obj := range objects { + id := obj.Key + if prefix != "" { + id = strings.TrimPrefix(id, prefix+"/") + } + result = append(result, object{ + ID: id, + Size: obj.Size, + LastModified: obj.LastModified, + }) + } + return result, nil +} diff --git a/attachment/store.go b/attachment/store.go index 302eb585..f66cd35c 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -5,21 +5,193 @@ import ( "fmt" "io" "regexp" + "sync" + "time" + 
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/model" + "heckel.io/ntfy/v2/s3" "heckel.io/ntfy/v2/util" ) +const ( + tagStore = "attachment_cache" + syncInterval = 15 * time.Minute // How often to run the background sync loop + orphanGracePeriod = time.Hour // Don't delete orphaned objects younger than this to avoid races with in-flight uploads +) + var ( fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, model.MessageIDLength)) errInvalidFileID = errors.New("invalid file ID") ) -// Store is an interface for storing and retrieving attachment files -type Store interface { - Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) - Read(id string) (io.ReadCloser, int64, error) - Remove(ids ...string) error - Size() int64 - Remaining() int64 +// Store manages attachment storage with shared logic for size tracking, limiting, +// ID validation, and background sync to reconcile storage with the database. +type Store struct { + backend backend + totalSizeCurrent int64 + totalSizeLimit int64 + localIDs func() ([]string, error) // returns IDs that should exist + closeChan chan struct{} + mu sync.Mutex // Protects totalSizeCurrent +} + +// NewFileStore creates a new file-system backed attachment cache +func NewFileStore(dir string, totalSizeLimit int64, localIDsFn func() ([]string, error)) (*Store, error) { + backend, err := newFileBackend(dir) + if err != nil { + return nil, err + } + return newStore(backend, totalSizeLimit, localIDsFn) +} + +// NewS3Store creates a new S3-backed attachment cache. 
The s3URL must be in the format: +// +// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] +func NewS3Store(s3URL string, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) { + config, err := s3.ParseURL(s3URL) + if err != nil { + return nil, err + } + return newStore(newS3Backend(s3.New(config)), totalSizeLimit, localIDs) +} + +func newStore(backend backend, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) { + c := &Store{ + backend: backend, + totalSizeLimit: totalSizeLimit, + localIDs: localIDs, + closeChan: make(chan struct{}), + } + if localIDs != nil { + go c.syncLoop() + } + return c, nil +} + +// Write stores an attachment file. The id is validated, and the write is subject to +// the total size limit and any additional limiters. +func (c *Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { + if !fileIDRegex.MatchString(id) { + return 0, errInvalidFileID + } + log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment") + limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) + cr := util.NewCountingReader(in) + lr := util.NewLimitReader(cr, limiters...) + if err := c.backend.Put(id, lr); err != nil { + c.backend.Delete(id) //nolint:errcheck + return 0, err + } + size := cr.Total() + c.mu.Lock() + c.totalSizeCurrent += size + c.mu.Unlock() + return size, nil +} + +// Read retrieves an attachment file by ID +func (c *Store) Read(id string) (io.ReadCloser, int64, error) { + if !fileIDRegex.MatchString(id) { + return nil, 0, errInvalidFileID + } + return c.backend.Get(id) +} + +// Remove deletes attachment files by ID. It does NOT recompute the total size; +// the next sync() call will correct it. +func (c *Store) Remove(ids ...string) error { + for _, id := range ids { + if !fileIDRegex.MatchString(id) { + return errInvalidFileID + } + } + return c.backend.Delete(ids...) +} + +// sync reconciles the backend storage with the database. 
It lists all objects, +// deletes orphans (not in the valid ID set and older than 1 hour), and recomputes +// the total size from the remaining objects. +func (c *Store) sync() error { + log.Tag(tagStore).Debug("Sync: starting sync loop") + localIDs, err := c.localIDs() + if err != nil { + return fmt.Errorf("attachment sync: failed to get valid IDs: %w", err) + } + localIDMap := make(map[string]struct{}, len(localIDs)) + for _, id := range localIDs { + localIDMap[id] = struct{}{} + } + remoteObjects, err := c.backend.List() + if err != nil { + return fmt.Errorf("attachment sync: failed to list objects: %w", err) + } + // Calculate total cache size and collect orphaned attachments, excluding objects younger + // than the grace period to account for races, and skipping objects with invalid IDs. + cutoff := time.Now().Add(-orphanGracePeriod) + var orphanIDs []string + var totalSize int64 + for _, obj := range remoteObjects { + if !fileIDRegex.MatchString(obj.ID) { + continue + } + if _, ok := localIDMap[obj.ID]; !ok && obj.LastModified.Before(cutoff) { + orphanIDs = append(orphanIDs, obj.ID) + } else { + totalSize += obj.Size + } + } + log.Tag(tagStore).Debug("Sync: cache size updated to %s", util.FormatSizeHuman(totalSize)) + c.mu.Lock() + c.totalSizeCurrent = totalSize + c.mu.Unlock() + // Delete orphaned attachments + if len(orphanIDs) > 0 { + log.Tag(tagStore).Debug("Sync: deleting %d orphaned attachment(s)", len(orphanIDs)) + if err := c.backend.Delete(orphanIDs...); err != nil { + return fmt.Errorf("attachment sync: failed to delete orphaned objects: %w", err) + } + } + return nil +} + +// Size returns the current total size of all attachments +func (c *Store) Size() int64 { + c.mu.Lock() + defer c.mu.Unlock() + return c.totalSizeCurrent +} + +// Remaining returns the remaining capacity for attachments +func (c *Store) Remaining() int64 { + c.mu.Lock() + defer c.mu.Unlock() + remaining := c.totalSizeLimit - c.totalSizeCurrent + if remaining < 0 { + return 0 + } 
+ return remaining +} + +// Close stops the background sync goroutine +func (c *Store) Close() { + close(c.closeChan) +} + +func (c *Store) syncLoop() { + if err := c.sync(); err != nil { + log.Tag(tagStore).Err(err).Warn("Attachment sync failed") + } + ticker := time.NewTicker(syncInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := c.sync(); err != nil { + log.Tag(tagStore).Err(err).Warn("Attachment sync failed") + } + case <-c.closeChan: + return + } + } } diff --git a/attachment/store_file.go b/attachment/store_file.go deleted file mode 100644 index b26e86f0..00000000 --- a/attachment/store_file.go +++ /dev/null @@ -1,139 +0,0 @@ -package attachment - -import ( - "errors" - "io" - "os" - "path/filepath" - "sync" - - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/util" -) - -const tagFileStore = "file_store" - -var errFileExists = errors.New("file exists") - -type fileStore struct { - dir string - totalSizeCurrent int64 - totalSizeLimit int64 - mu sync.Mutex -} - -// NewFileStore creates a new file-system backed attachment store -func NewFileStore(dir string, totalSizeLimit int64) (Store, error) { - if err := os.MkdirAll(dir, 0700); err != nil { - return nil, err - } - size, err := dirSize(dir) - if err != nil { - return nil, err - } - return &fileStore{ - dir: dir, - totalSizeCurrent: size, - totalSizeLimit: totalSizeLimit, - }, nil -} - -func (c *fileStore) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { - if !fileIDRegex.MatchString(id) { - return 0, errInvalidFileID - } - log.Tag(tagFileStore).Field("message_id", id).Debug("Writing attachment") - file := filepath.Join(c.dir, id) - if _, err := os.Stat(file); err == nil { - return 0, errFileExists - } - f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) - if err != nil { - return 0, err - } - defer f.Close() - limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - limitWriter := util.NewLimitWriter(f, limiters...) 
- size, err := io.Copy(limitWriter, in) - if err != nil { - os.Remove(file) - return 0, err - } - if err := f.Close(); err != nil { - os.Remove(file) - return 0, err - } - c.mu.Lock() - c.totalSizeCurrent += size - c.mu.Unlock() - return size, nil -} - -func (c *fileStore) Read(id string) (io.ReadCloser, int64, error) { - if !fileIDRegex.MatchString(id) { - return nil, 0, errInvalidFileID - } - file := filepath.Join(c.dir, id) - stat, err := os.Stat(file) - if err != nil { - return nil, 0, err - } - f, err := os.Open(file) - if err != nil { - return nil, 0, err - } - return f, stat.Size(), nil -} - -func (c *fileStore) Remove(ids ...string) error { - for _, id := range ids { - if !fileIDRegex.MatchString(id) { - return errInvalidFileID - } - log.Tag(tagFileStore).Field("message_id", id).Debug("Deleting attachment") - file := filepath.Join(c.dir, id) - if err := os.Remove(file); err != nil { - log.Tag(tagFileStore).Field("message_id", id).Err(err).Debug("Error deleting attachment") - } - } - size, err := dirSize(c.dir) - if err != nil { - return err - } - c.mu.Lock() - c.totalSizeCurrent = size - c.mu.Unlock() - return nil -} - -func (c *fileStore) Size() int64 { - c.mu.Lock() - defer c.mu.Unlock() - return c.totalSizeCurrent -} - -func (c *fileStore) Remaining() int64 { - c.mu.Lock() - defer c.mu.Unlock() - remaining := c.totalSizeLimit - c.totalSizeCurrent - if remaining < 0 { - return 0 - } - return remaining -} - -func dirSize(dir string) (int64, error) { - entries, err := os.ReadDir(dir) - if err != nil { - return 0, err - } - var size int64 - for _, e := range entries { - info, err := e.Info() - if err != nil { - return 0, err - } - size += info.Size() - } - return size, nil -} diff --git a/attachment/store_file_test.go b/attachment/store_file_test.go index 5cfe0db4..ceac09d7 100644 --- a/attachment/store_file_test.go +++ b/attachment/store_file_test.go @@ -7,6 +7,7 @@ import ( "os" "strings" "testing" + "time" "github.com/stretchr/testify/require" 
"heckel.io/ntfy/v2/util" @@ -17,22 +18,22 @@ var ( ) func TestFileStore_Write_Success(t *testing.T) { - dir, s := newTestFileStore(t) - size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) + dir, c := newTestFileStore(t) + size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) require.Nil(t, err) require.Equal(t, int64(11), size) require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl")) - require.Equal(t, int64(11), s.Size()) - require.Equal(t, int64(10229), s.Remaining()) + require.Equal(t, int64(11), c.Size()) + require.Equal(t, int64(10229), c.Remaining()) } func TestFileStore_Write_Read_Success(t *testing.T) { - _, s := newTestFileStore(t) - size, err := s.Write("abcdefghijkl", strings.NewReader("hello world")) + _, c := newTestFileStore(t) + size, err := c.Write("abcdefghijkl", strings.NewReader("hello world")) require.Nil(t, err) require.Equal(t, int64(11), size) - reader, readSize, err := s.Read("abcdefghijkl") + reader, readSize, err := c.Read("abcdefghijkl") require.Nil(t, err) require.Equal(t, int64(11), readSize) defer reader.Close() @@ -42,57 +43,111 @@ func TestFileStore_Write_Read_Success(t *testing.T) { } func TestFileStore_Write_Remove_Success(t *testing.T) { - dir, s := newTestFileStore(t) // max = 10k (10240), each = 1k (1024) + dir, c := newTestFileStore(t) // max = 10k (10240), each = 1k (1024) for i := 0; i < 10; i++ { // 10x999 = 9990 - size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) + size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) require.Nil(t, err) require.Equal(t, int64(999), size) } - require.Equal(t, int64(9990), s.Size()) - require.Equal(t, int64(250), s.Remaining()) + require.Equal(t, int64(9990), c.Size()) + require.Equal(t, int64(250), c.Remaining()) require.FileExists(t, dir+"/abcdefghijk1") require.FileExists(t, dir+"/abcdefghijk5") - require.Nil(t, 
s.Remove("abcdefghijk1", "abcdefghijk5")) + require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5")) require.NoFileExists(t, dir+"/abcdefghijk1") require.NoFileExists(t, dir+"/abcdefghijk5") - require.Equal(t, int64(7992), s.Size()) - require.Equal(t, int64(2248), s.Remaining()) + // Size is not recomputed by Remove; it stays stale until next sync + require.Equal(t, int64(9990), c.Size()) } func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) { - dir, s := newTestFileStore(t) + dir, c := newTestFileStore(t) for i := 0; i < 10; i++ { - size, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) + size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) require.Nil(t, err) require.Equal(t, int64(1024), size) } - _, err := s.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) + _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abcdefghijkX") } func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) { - dir, s := newTestFileStore(t) - _, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) + dir, c := newTestFileStore(t) + _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abcdefghijkl") } func TestFileStore_Read_NotFound(t *testing.T) { - _, s := newTestFileStore(t) - _, _, err := s.Read("abcdefghijkl") + _, c := newTestFileStore(t) + _, _, err := c.Read("abcdefghijkl") require.Error(t, err) } -func newTestFileStore(t *testing.T) (dir string, store Store) { - dir = t.TempDir() - store, err := NewFileStore(dir, 10*1024) +func TestFileStore_Sync(t *testing.T) { + dir, c := newTestFileStore(t) + + // Write some files + _, err := c.Write("abcdefghijk0", strings.NewReader("file0")) require.Nil(t, err) - return dir, store + _, err = 
c.Write("abcdefghijk1", strings.NewReader("file1")) + require.Nil(t, err) + _, err = c.Write("abcdefghijk2", strings.NewReader("file2")) + require.Nil(t, err) + + require.Equal(t, int64(15), c.Size()) + + // Set the ID provider to only know about file 0 and 2 + c.localIDs = func() ([]string, error) { + return []string{"abcdefghijk0", "abcdefghijk2"}, nil + } + + // Make file 1's mod time old enough to be cleaned up (> 1 hour) + oldTime := time.Unix(1, 0) + os.Chtimes(dir+"/abcdefghijk1", oldTime, oldTime) + + // Run sync + require.Nil(t, c.sync()) + + // File 1 should be deleted (orphan, old enough) + require.NoFileExists(t, dir+"/abcdefghijk1") + require.FileExists(t, dir+"/abcdefghijk0") + require.FileExists(t, dir+"/abcdefghijk2") + + // Size should be updated + require.Equal(t, int64(10), c.Size()) +} + +func TestFileStore_Sync_SkipsRecentFiles(t *testing.T) { + dir, c := newTestFileStore(t) + + // Write a file + _, err := c.Write("abcdefghijk0", strings.NewReader("file0")) + require.Nil(t, err) + + // Set the ID provider to return empty (no valid IDs) + c.localIDs = func() ([]string, error) { + return []string{}, nil + } + + // File was just created, so it should NOT be deleted (< 1 hour old) + require.Nil(t, c.sync()) + require.FileExists(t, dir+"/abcdefghijk0") +} + +func newTestFileStore(t *testing.T) (dir string, cache *Store) { + t.Helper() + dir = t.TempDir() + cache, err := NewFileStore(dir, 10*1024, nil) + require.Nil(t, err) + t.Cleanup(func() { cache.Close() }) + return dir, cache } func readFile(t *testing.T, f string) string { + t.Helper() b, err := os.ReadFile(f) require.Nil(t, err) return string(b) diff --git a/attachment/store_s3.go b/attachment/store_s3.go deleted file mode 100644 index 38f0353a..00000000 --- a/attachment/store_s3.go +++ /dev/null @@ -1,150 +0,0 @@ -package attachment - -import ( - "context" - "fmt" - "io" - "sync" - - "heckel.io/ntfy/v2/log" - "heckel.io/ntfy/v2/s3" - "heckel.io/ntfy/v2/util" -) - -const ( - tagS3Store = 
"s3_store" -) - -type s3Store struct { - client *s3.Client - totalSizeCurrent int64 - totalSizeLimit int64 - mu sync.Mutex -} - -// NewS3Store creates a new S3-backed attachment store. The s3URL must be in the format: -// -// s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] -func NewS3Store(s3URL string, totalSizeLimit int64) (Store, error) { - cfg, err := s3.ParseURL(s3URL) - if err != nil { - return nil, err - } - store := &s3Store{ - client: s3.New(cfg), - totalSizeLimit: totalSizeLimit, - } - if totalSizeLimit > 0 { - size, err := store.computeSize() - if err != nil { - return nil, fmt.Errorf("s3 store: failed to compute initial size: %w", err) - } - store.totalSizeCurrent = size - } - return store, nil -} - -func (c *s3Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { - if !fileIDRegex.MatchString(id) { - return 0, errInvalidFileID - } - log.Tag(tagS3Store).Field("message_id", id).Debug("Writing attachment to S3") - - // Stream through limiters via an io.Pipe directly to S3. PutObject supports chunked - // uploads, so no temp file or Content-Length is needed. - limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - pr, pw := io.Pipe() - lw := util.NewLimitWriter(pw, limiters...) 
- var size int64 - var copyErr error - done := make(chan struct{}) - go func() { - defer close(done) - size, copyErr = io.Copy(lw, in) - if copyErr != nil { - pw.CloseWithError(copyErr) - } else { - pw.Close() - } - }() - putErr := c.client.PutObject(context.Background(), id, pr) - pr.Close() - <-done - if copyErr != nil { - return 0, copyErr - } - if putErr != nil { - return 0, putErr - } - c.mu.Lock() - c.totalSizeCurrent += size - c.mu.Unlock() - return size, nil -} - -func (c *s3Store) Read(id string) (io.ReadCloser, int64, error) { - if !fileIDRegex.MatchString(id) { - return nil, 0, errInvalidFileID - } - return c.client.GetObject(context.Background(), id) -} - -func (c *s3Store) Remove(ids ...string) error { - for _, id := range ids { - if !fileIDRegex.MatchString(id) { - return errInvalidFileID - } - } - // S3 DeleteObjects supports up to 1000 keys per call - for i := 0; i < len(ids); i += 1000 { - end := i + 1000 - if end > len(ids) { - end = len(ids) - } - batch := ids[i:end] - for _, id := range batch { - log.Tag(tagS3Store).Field("message_id", id).Debug("Deleting attachment from S3") - } - if err := c.client.DeleteObjects(context.Background(), batch); err != nil { - return err - } - } - // Recalculate totalSizeCurrent via ListObjectsV2 (matches fileStore's dirSize rescan pattern) - size, err := c.computeSize() - if err != nil { - return fmt.Errorf("s3 store: failed to compute size after remove: %w", err) - } - c.mu.Lock() - c.totalSizeCurrent = size - c.mu.Unlock() - return nil -} - -func (c *s3Store) Size() int64 { - c.mu.Lock() - defer c.mu.Unlock() - return c.totalSizeCurrent -} - -func (c *s3Store) Remaining() int64 { - c.mu.Lock() - defer c.mu.Unlock() - remaining := c.totalSizeLimit - c.totalSizeCurrent - if remaining < 0 { - return 0 - } - return remaining -} - -// computeSize uses ListAllObjects to sum up the total size of all objects with our prefix. 
-func (c *s3Store) computeSize() (int64, error) { - objects, err := c.client.ListAllObjects(context.Background()) - if err != nil { - return 0, err - } - var totalSize int64 - for _, obj := range objects { - totalSize += obj.Size - } - return totalSize, nil -} diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 872e8c23..e29f9ed6 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -10,6 +10,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/stretchr/testify/require" "heckel.io/ntfy/v2/s3" @@ -22,16 +23,16 @@ func TestS3Store_WriteReadRemove(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) // Write - size, err := store.Write("abcdefghijkl", strings.NewReader("hello world")) + size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world")) require.Nil(t, err) require.Equal(t, int64(11), size) - require.Equal(t, int64(11), store.Size()) + require.Equal(t, int64(11), cache.Size()) // Read back - reader, readSize, err := store.Read("abcdefghijkl") + reader, readSize, err := cache.Read("abcdefghijkl") require.Nil(t, err) require.Equal(t, int64(11), readSize) data, err := io.ReadAll(reader) @@ -40,11 +41,11 @@ func TestS3Store_WriteReadRemove(t *testing.T) { require.Equal(t, "hello world", string(data)) // Remove - require.Nil(t, store.Remove("abcdefghijkl")) - require.Equal(t, int64(0), store.Size()) + require.Nil(t, cache.Remove("abcdefghijkl")) + // Size is not recomputed by Remove; stays stale until next sync // Read after remove should fail - _, _, err = store.Read("abcdefghijkl") + _, _, err = cache.Read("abcdefghijkl") require.Error(t, err) } @@ -52,13 +53,13 @@ func TestS3Store_WriteNoPrefix(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "", 10*1024) + cache := newTestS3Store(t, 
server, "my-bucket", "", 10*1024) - size, err := store.Write("abcdefghijkl", strings.NewReader("test")) + size, err := cache.Write("abcdefghijkl", strings.NewReader("test")) require.Nil(t, err) require.Equal(t, int64(4), size) - reader, _, err := store.Read("abcdefghijkl") + reader, _, err := cache.Read("abcdefghijkl") require.Nil(t, err) data, err := io.ReadAll(reader) reader.Close() @@ -70,52 +71,53 @@ func TestS3Store_WriteTotalSizeLimit(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 100) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 100) // First write fits - _, err := store.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80))) + _, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80))) require.Nil(t, err) - require.Equal(t, int64(80), store.Size()) - require.Equal(t, int64(20), store.Remaining()) + require.Equal(t, int64(80), cache.Size()) + require.Equal(t, int64(20), cache.Remaining()) // Second write exceeds total limit - _, err = store.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50))) - require.Equal(t, util.ErrLimitReached, err) + _, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50))) + require.ErrorIs(t, err, util.ErrLimitReached) } func TestS3Store_WriteFileSizeLimit(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, err := store.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100)) - require.Equal(t, util.ErrLimitReached, err) + _, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100)) + require.ErrorIs(t, err, util.ErrLimitReached) } func TestS3Store_WriteRemoveMultiple(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + 
cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) for i := 0; i < 5; i++ { - _, err := store.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100))) + _, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100))) require.Nil(t, err) } - require.Equal(t, int64(500), store.Size()) + require.Equal(t, int64(500), cache.Size()) - require.Nil(t, store.Remove("abcdefghijk1", "abcdefghijk3")) - require.Equal(t, int64(300), store.Size()) + require.Nil(t, cache.Remove("abcdefghijk1", "abcdefghijk3")) + // Size not recomputed by Remove + require.Equal(t, int64(500), cache.Size()) } func TestS3Store_ReadNotFound(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, _, err := store.Read("abcdefghijkl") + _, _, err := cache.Read("abcdefghijkl") require.Error(t, err) } @@ -123,42 +125,93 @@ func TestS3Store_InvalidID(t *testing.T) { server := newMockS3Server() defer server.Close() - store := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, err := store.Write("bad", strings.NewReader("x")) + _, err := cache.Write("bad", strings.NewReader("x")) require.Equal(t, errInvalidFileID, err) - _, _, err = store.Read("bad") + _, _, err = cache.Read("bad") require.Equal(t, errInvalidFileID, err) - err = store.Remove("bad") + err = cache.Remove("bad") require.Equal(t, errInvalidFileID, err) } +func TestS3Store_Sync(t *testing.T) { + server := newMockS3Server() + defer server.Close() + + cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) + + // Write some files + _, err := cache.Write("abcdefghijk0", strings.NewReader("file0")) + require.Nil(t, err) + _, err = cache.Write("abcdefghijk1", strings.NewReader("file1")) + require.Nil(t, err) + _, err = cache.Write("abcdefghijk2", 
strings.NewReader("file2")) + require.Nil(t, err) + + require.Equal(t, int64(15), cache.Size()) + + // Set the ID provider to only know about file 0 and 2 + // All mock objects have LastModified set to 2 hours ago, so orphans are eligible for deletion + cache.localIDs = func() ([]string, error) { + return []string{"abcdefghijk0", "abcdefghijk2"}, nil + } + + // Run sync + require.Nil(t, cache.sync()) + + // File 1 should be deleted (orphan) + _, _, err = cache.Read("abcdefghijk1") + require.Error(t, err) + + // Size should be updated + require.Equal(t, int64(10), cache.Size()) +} + +func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) { + mockServer := newMockS3ServerWithModTime(time.Now()) + defer mockServer.Close() + + cache := newTestS3Store(t, mockServer, "my-bucket", "pfx", 10*1024) + + _, err := cache.Write("abcdefghijk0", strings.NewReader("file0")) + require.Nil(t, err) + + // Set the ID provider to return empty (no valid IDs) + cache.localIDs = func() ([]string, error) { + return []string{}, nil + } + + // File was "just created" (mock returns recent time), so it should NOT be deleted + require.Nil(t, cache.sync()) + + // File should still exist + reader, _, err := cache.Read("abcdefghijk0") + require.Nil(t, err) + reader.Close() +} + // --- Helpers --- -func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) Store { +func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store { t.Helper() - // httptest.NewTLSServer URL is like "https://127.0.0.1:PORT" host := strings.TrimPrefix(server.URL, "https://") - s := &s3Store{ - client: &s3.Client{ - AccessKey: "AKID", - SecretKey: "SECRET", - Region: "us-east-1", - Endpoint: host, - Bucket: bucket, - Prefix: prefix, - PathStyle: true, - HTTPClient: server.Client(), - }, - totalSizeLimit: totalSizeLimit, - } - // Compute initial size (should be 0 for fresh mock) - size, err := s.computeSize() + backend := 
newS3Backend(&s3.Client{ + AccessKey: "AKID", + SecretKey: "SECRET", + Region: "us-east-1", + Endpoint: host, + Bucket: bucket, + Prefix: prefix, + PathStyle: true, + HTTPClient: server.Client(), + }) + cache, err := newStore(backend, totalSizeLimit, nil) require.Nil(t, err) - s.totalSizeCurrent = size - return s + t.Cleanup(func() { cache.Close() }) + return cache } // --- Mock S3 server --- @@ -167,16 +220,22 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body - uploads map[string]map[int][]byte // uploadID -> partNumber -> data - nextID int // counter for generating upload IDs - mu sync.RWMutex + objects map[string][]byte // full key (bucket/key) -> body + uploads map[string]map[int][]byte // uploadID -> partNumber -> data + nextID int // counter for generating upload IDs + lastModTime time.Time // time to return for LastModified in list responses + mu sync.RWMutex } func newMockS3Server() *httptest.Server { + return newMockS3ServerWithModTime(time.Now().Add(-2 * time.Hour)) +} + +func newMockS3ServerWithModTime(modTime time.Time) *httptest.Server { m := &mockS3Server{ - objects: make(map[string][]byte), - uploads: make(map[string]map[int][]byte), + objects: make(map[string][]byte), + uploads: make(map[string]map[int][]byte), + lastModTime: modTime, } return httptest.NewTLSServer(m) } @@ -341,7 +400,11 @@ func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucket continue // different bucket } if prefix == "" || strings.HasPrefix(objKey, prefix) { - contents = append(contents, s3ListObject{Key: objKey, Size: int64(len(body))}) + contents = append(contents, s3ListObject{ + Key: objKey, + Size: int64(len(body)), + LastModified: m.lastModTime.Format(time.RFC3339), + }) } } m.mu.RUnlock() @@ -362,6 +425,7 @@ type s3ListResponse struct 
{ } type s3ListObject struct { - Key string `xml:"Key"` - Size int64 `xml:"Size"` + Key string `xml:"Key"` + Size int64 `xml:"Size"` + LastModified string `xml:"LastModified"` } diff --git a/message/cache.go b/message/cache.go index 76aba4be..dd4ef0a4 100644 --- a/message/cache.go +++ b/message/cache.go @@ -46,6 +46,7 @@ type queries struct { selectStats string updateStats string updateMessageTime string + selectAttachmentIDs string } // Cache stores published messages @@ -252,18 +253,7 @@ func (c *Cache) MessagesExpired() ([]string, error) { return nil, err } defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return ids, nil + return readStrings(rows) } // Message returns the message with the given ID, or ErrMessageNotFound if not found @@ -319,18 +309,7 @@ func (c *Cache) Topics() ([]string, error) { return nil, err } defer rows.Close() - topics := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - topics = append(topics, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return topics, nil + return readStrings(rows) } // DeleteMessages deletes the messages with the given IDs @@ -358,15 +337,8 @@ func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, return nil, err } defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { + ids, err := readStrings(rows) + if err != nil { return nil, err } rows.Close() // Close rows before executing delete in same transaction @@ -391,6 +363,16 @@ func (c *Cache) ExpireMessages(topics ...string) error { }) } +// AttachmentIDs returns message IDs with active (non-expired, non-deleted) 
attachments +func (c *Cache) AttachmentIDs() ([]string, error) { + rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentIDs, time.Now().Unix()) + if err != nil { + return nil, err + } + defer rows.Close() + return readStrings(rows) +} + // AttachmentsExpired returns message IDs with expired attachments that have not been deleted func (c *Cache) AttachmentsExpired() ([]string, error) { rows, err := c.db.Query(c.queries.selectAttachmentsExpired, time.Now().Unix()) @@ -398,18 +380,7 @@ func (c *Cache) AttachmentsExpired() ([]string, error) { return nil, err } defer rows.Close() - ids := make([]string, 0) - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - return nil, err - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - return nil, err - } - return ids, nil + return readStrings(rows) } // MarkAttachmentsDeleted marks the attachments for the given message IDs as deleted @@ -590,3 +561,18 @@ func readMessage(rows *sql.Rows) (*model.Message, error) { Encoding: encoding, }, nil } + +func readStrings(rows *sql.Rows) ([]string, error) { + strs := make([]string, 0) + for rows.Next() { + var s string + if err := rows.Scan(&s); err != nil { + return nil, err + } + strs = append(strs, s) + } + if err := rows.Err(); err != nil { + return nil, err + } + return strs, nil +} diff --git a/message/cache_postgres.go b/message/cache_postgres.go index ba162da2..d59b2590 100644 --- a/message/cache_postgres.go +++ b/message/cache_postgres.go @@ -74,6 +74,8 @@ const ( postgresSelectStatsQuery = `SELECT value FROM message_stats WHERE key = 'messages'` postgresUpdateStatsQuery = `UPDATE message_stats SET value = $1 WHERE key = 'messages'` postgresUpdateMessageTimeQuery = `UPDATE message SET time = $1 WHERE mid = $2` + + postgresSelectAttachmentIDsQuery = `SELECT mid FROM message WHERE attachment_expires > $1 AND attachment_deleted = FALSE` ) var postgresQueries = queries{ @@ -100,6 +102,7 @@ var postgresQueries = queries{ selectStats: 
postgresSelectStatsQuery, updateStats: postgresUpdateStatsQuery, updateMessageTime: postgresUpdateMessageTimeQuery, + selectAttachmentIDs: postgresSelectAttachmentIDsQuery, } // NewPostgresStore creates a new PostgreSQL-backed message cache store using an existing database connection pool. diff --git a/message/cache_sqlite.go b/message/cache_sqlite.go index a36aba0e..6126f1e1 100644 --- a/message/cache_sqlite.go +++ b/message/cache_sqlite.go @@ -77,6 +77,8 @@ const ( sqliteSelectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'` sqliteUpdateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'` sqliteUpdateMessageTimeQuery = `UPDATE messages SET time = ? WHERE mid = ?` + + sqliteSelectAttachmentIDsQuery = `SELECT mid FROM messages WHERE attachment_expires > ? AND attachment_deleted = 0` ) var sqliteQueries = queries{ @@ -103,6 +105,7 @@ var sqliteQueries = queries{ selectStats: sqliteSelectStatsQuery, updateStats: sqliteUpdateStatsQuery, updateMessageTime: sqliteUpdateMessageTimeQuery, + selectAttachmentIDs: sqliteSelectAttachmentIDsQuery, } // NewSQLiteStore creates a SQLite file-backed cache diff --git a/s3/client.go b/s3/client.go index 56f9608e..5ec8caf6 100644 --- a/s3/client.go +++ b/s3/client.go @@ -196,7 +196,15 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK } objects := make([]Object, len(result.Contents)) for i, obj := range result.Contents { - objects[i] = Object(obj) + var lastModified time.Time + if obj.LastModified != "" { + lastModified, _ = time.Parse(time.RFC3339, obj.LastModified) + } + objects[i] = Object{ + Key: obj.Key, + Size: obj.Size, + LastModified: lastModified, + } } return &ListResult{ Objects: objects, diff --git a/s3/client_test.go b/s3/client_test.go index c3a8fe2c..8007601c 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -13,6 +13,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -243,7 +244,7 @@ func (m *mockS3Server) 
handleList(w http.ResponseWriter, r *http.Request, bucket var contents []listObject for _, objKey := range allKeys { body := m.objects[bucketPath+"/"+objKey] - contents = append(contents, listObject{Key: objKey, Size: int64(len(body))}) + contents = append(contents, listObject{Key: objKey, Size: int64(len(body)), LastModified: time.Now().Format(time.RFC3339)}) } m.mu.RUnlock() diff --git a/s3/types.go b/s3/types.go index 201c570b..65615fcd 100644 --- a/s3/types.go +++ b/s3/types.go @@ -1,6 +1,9 @@ package s3 -import "fmt" +import ( + "fmt" + "time" +) // Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string. type Config struct { @@ -15,8 +18,9 @@ type Config struct { // Object represents an S3 object returned by list operations. type Object struct { - Key string - Size int64 + Key string + Size int64 + LastModified time.Time } // ListResult holds the response from a ListObjectsV2 call. @@ -49,8 +53,9 @@ type listObjectsV2Response struct { } type listObject struct { - Key string `xml:"Key"` - Size int64 `xml:"Size"` + Key string `xml:"Key"` + Size int64 `xml:"Size"` + LastModified string `xml:"LastModified"` } // deleteResult is the XML response from S3 DeleteObjects diff --git a/server/server.go b/server/server.go index 0972f00d..b43b71ef 100644 --- a/server/server.go +++ b/server/server.go @@ -65,7 +65,7 @@ type Server struct { userManager *user.Manager // Might be nil! messageCache *message.Cache // Database that stores the messages webPush *webpush.Store // Database that stores web push subscriptions - fileCache attachment.Store // Attachment store (file system or S3) + fileCache *attachment.Store // Attachment store (file system or S3) stripe stripeAPI // Stripe API, can be replaced with a mock priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) 
metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set @@ -229,7 +229,7 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } - fileCache, err := createAttachmentStore(conf) + fileCache, err := createAttachmentStore(conf, messageCache) if err != nil { return nil, err } @@ -300,11 +300,14 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) { return message.NewMemStore() } -func createAttachmentStore(conf *Config) (attachment.Store, error) { +func createAttachmentStore(conf *Config, messageCache *message.Cache) (*attachment.Store, error) { + idProvider := func() ([]string, error) { + return messageCache.AttachmentIDs() + } if conf.AttachmentS3URL != "" { - return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit) + return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit, idProvider) } else if conf.AttachmentCacheDir != "" { - return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit) + return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, idProvider) } return nil, nil } @@ -429,6 +432,9 @@ func (s *Server) Stop() { if s.smtpServer != nil { s.smtpServer.Close() } + if s.fileCache != nil { + s.fileCache.Close() + } s.closeDatabases() close(s.closeChan) } diff --git a/util/limit.go b/util/limit.go index ad2118c7..9c39d3dc 100644 --- a/util/limit.go +++ b/util/limit.go @@ -152,6 +152,61 @@ func (l *RateLimiter) Reset() { l.value = 0 } +// CountingReader wraps an io.Reader and counts the number of bytes read through it. 
+type CountingReader struct { + r io.Reader + total int64 +} + +// NewCountingReader creates a new CountingReader +func NewCountingReader(r io.Reader) *CountingReader { + return &CountingReader{r: r} +} + +// Read passes through to the underlying reader and counts the bytes read +func (r *CountingReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + r.total += int64(n) + return +} + +// Total returns the total number of bytes read so far +func (r *CountingReader) Total() int64 { + return r.total +} + +// LimitReader implements an io.Reader that will pass through all Read calls to the underlying +// reader r until any of the limiter's limit is reached, at which point a Read will return ErrLimitReached. +// Each limiter's value is increased after every read based on the number of bytes actually read. +type LimitReader struct { + r io.Reader + limiters []Limiter +} + +// NewLimitReader creates a new LimitReader +func NewLimitReader(r io.Reader, limiters ...Limiter) *LimitReader { + return &LimitReader{ + r: r, + limiters: limiters, + } +} + +// Read passes through all reads to the underlying reader until any of the given limiter's limit is reached +func (r *LimitReader) Read(p []byte) (n int, err error) { + n, err = r.r.Read(p) + if n > 0 { + for i := 0; i < len(r.limiters); i++ { + if !r.limiters[i].AllowN(int64(n)) { + for j := i - 1; j >= 0; j-- { + r.limiters[j].AllowN(-int64(n)) // Revert limiters if not allowed + } + return 0, ErrLimitReached + } + } + } + return +} + // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying // writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached. // Each limiter's value is increased with every write. 
From a1b403d23cf1a478e52c4c99dd79f9cc92bfa36a Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 19 Mar 2026 21:11:36 -0400 Subject: [PATCH 08/32] Remove s3 config option, reduce size when removing files --- attachment/backend.go | 3 +- attachment/backend_file.go | 41 +++++++++++--------- attachment/backend_s3.go | 49 +++++++++++++----------- attachment/store.go | 69 +++++++++++++++++++++++----------- attachment/store_file_test.go | 4 +- attachment/store_s3_test.go | 5 +-- cmd/serve.go | 9 +---- docs/config.md | 17 ++++----- docs/releases.md | 2 +- s3/client.go | 71 +++++++++++++++++++++++++++++++++++ s3/types.go | 21 +++++++++++ server/config.go | 1 - server/server.go | 8 ++-- server/server.yml | 4 +- 14 files changed, 211 insertions(+), 93 deletions(-) diff --git a/attachment/backend.go b/attachment/backend.go index 8989b890..e95fc91e 100644 --- a/attachment/backend.go +++ b/attachment/backend.go @@ -17,6 +17,7 @@ type object struct { type backend interface { Put(id string, in io.Reader) error Get(id string) (io.ReadCloser, int64, error) - Delete(ids ...string) error List() ([]object, error) + Delete(ids ...string) error + DeleteIncomplete(cutoff time.Time) error } diff --git a/attachment/backend_file.go b/attachment/backend_file.go index b0afb2ca..8aaf20b9 100644 --- a/attachment/backend_file.go +++ b/attachment/backend_file.go @@ -4,6 +4,7 @@ import ( "io" "os" "path/filepath" + "time" "heckel.io/ntfy/v2/log" ) @@ -41,6 +42,26 @@ func (b *fileBackend) Put(id string, in io.Reader) error { return nil } +func (b *fileBackend) List() ([]object, error) { + entries, err := os.ReadDir(b.dir) + if err != nil { + return nil, err + } + objects := make([]object, 0, len(entries)) + for _, e := range entries { + info, err := e.Info() + if err != nil { + return nil, err + } + objects = append(objects, object{ + ID: e.Name(), + Size: info.Size(), + LastModified: info.ModTime(), + }) + } + return objects, nil +} + func (b *fileBackend) Get(id string) (io.ReadCloser, 
int64, error) { file := filepath.Join(b.dir, id) stat, err := os.Stat(file) @@ -64,22 +85,6 @@ func (b *fileBackend) Delete(ids ...string) error { return nil } -func (b *fileBackend) List() ([]object, error) { - entries, err := os.ReadDir(b.dir) - if err != nil { - return nil, err - } - objects := make([]object, 0, len(entries)) - for _, e := range entries { - info, err := e.Info() - if err != nil { - return nil, err - } - objects = append(objects, object{ - ID: e.Name(), - Size: info.Size(), - LastModified: info.ModTime(), - }) - } - return objects, nil +func (b *fileBackend) DeleteIncomplete(_ time.Time) error { + return nil } diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index 8fcd8ccb..11d2254b 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -4,13 +4,16 @@ import ( "context" "io" "strings" - - "heckel.io/ntfy/v2/s3" + "time" "heckel.io/ntfy/v2/log" + "heckel.io/ntfy/v2/s3" ) -const tagS3Backend = "s3_backend" +const ( + tagS3Backend = "s3_backend" + deleteBatchSize = 1000 +) type s3Backend struct { client *s3.Client @@ -30,24 +33,6 @@ func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) { return b.client.GetObject(context.Background(), id) } -func (b *s3Backend) Delete(ids ...string) error { - // S3 DeleteObjects supports up to 1000 keys per call - for i := 0; i < len(ids); i += 1000 { - end := i + 1000 - if end > len(ids) { - end = len(ids) - } - batch := ids[i:end] - for _, id := range batch { - log.Tag(tagS3Backend).Field("message_id", id).Debug("Deleting attachment from S3") - } - if err := b.client.DeleteObjects(context.Background(), batch); err != nil { - return err - } - } - return nil -} - func (b *s3Backend) List() ([]object, error) { objects, err := b.client.ListAllObjects(context.Background()) if err != nil { @@ -68,3 +53,25 @@ func (b *s3Backend) List() ([]object, error) { } return result, nil } + +func (b *s3Backend) Delete(ids ...string) error { + // S3 DeleteObjects supports up to 1000 keys 
per call + for i := 0; i < len(ids); i += deleteBatchSize { + end := i + deleteBatchSize + if end > len(ids) { + end = len(ids) + } + batch := ids[i:end] + for _, id := range batch { + log.Tag(tagS3Backend).Field("message_id", id).Debug("Deleting attachment from S3") + } + if err := b.client.DeleteObjects(context.Background(), batch); err != nil { + return err + } + } + return nil +} + +func (b *s3Backend) DeleteIncomplete(cutoff time.Time) error { + return b.client.AbortIncompleteUploads(context.Background(), cutoff) +} diff --git a/attachment/store.go b/attachment/store.go index f66cd35c..78b8a7cc 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -28,21 +28,22 @@ var ( // Store manages attachment storage with shared logic for size tracking, limiting, // ID validation, and background sync to reconcile storage with the database. type Store struct { - backend backend - totalSizeCurrent int64 - totalSizeLimit int64 - localIDs func() ([]string, error) // returns IDs that should exist - closeChan chan struct{} - mu sync.Mutex // Protects totalSizeCurrent + backend backend + limit int64 // Defined limit of the store in bytes + size int64 // Current size of the store in bytes + sizes map[string]int64 // File ID -> size, for subtracting on Remove + localIDs func() ([]string, error) // Returns file IDs that should exist locally, used for sync() + closeChan chan struct{} + mu sync.Mutex // Protects size and sizes } // NewFileStore creates a new file-system backed attachment cache func NewFileStore(dir string, totalSizeLimit int64, localIDsFn func() ([]string, error)) (*Store, error) { - backend, err := newFileBackend(dir) + b, err := newFileBackend(dir) if err != nil { return nil, err } - return newStore(backend, totalSizeLimit, localIDsFn) + return newStore(b, totalSizeLimit, localIDsFn) } // NewS3Store creates a new S3-backed attachment cache. 
The s3URL must be in the format: @@ -58,10 +59,11 @@ func NewS3Store(s3URL string, totalSizeLimit int64, localIDs func() ([]string, e func newStore(backend backend, totalSizeLimit int64, localIDs func() ([]string, error)) (*Store, error) { c := &Store{ - backend: backend, - totalSizeLimit: totalSizeLimit, - localIDs: localIDs, - closeChan: make(chan struct{}), + backend: backend, + limit: totalSizeLimit, + sizes: make(map[string]int64), + localIDs: localIDs, + closeChan: make(chan struct{}), } if localIDs != nil { go c.syncLoop() @@ -85,7 +87,8 @@ func (c *Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, } size := cr.Total() c.mu.Lock() - c.totalSizeCurrent += size + c.size += size + c.sizes[id] = size c.mu.Unlock() return size, nil } @@ -98,15 +101,30 @@ func (c *Store) Read(id string) (io.ReadCloser, int64, error) { return c.backend.Get(id) } -// Remove deletes attachment files by ID. It does NOT recompute the total size; -// the next sync() call will correct it. +// Remove deletes attachment files by ID and subtracts their known sizes from +// the total. Sizes for objects not tracked (e.g. written before this process +// started and before the first sync) are corrected by the next sync() call. func (c *Store) Remove(ids ...string) error { for _, id := range ids { if !fileIDRegex.MatchString(id) { return errInvalidFileID } } - return c.backend.Delete(ids...) + if err := c.backend.Delete(ids...); err != nil { + return err + } + c.mu.Lock() + for _, id := range ids { + if size, ok := c.sizes[id]; ok { + c.size -= size + delete(c.sizes, id) + } + } + if c.size < 0 { + c.size = 0 + } + c.mu.Unlock() + return nil } // sync reconciles the backend storage with the database. It lists all objects, @@ -130,7 +148,8 @@ func (c *Store) sync() error { // than the grace period to account for races, and skipping objects with invalid IDs. 
cutoff := time.Now().Add(-orphanGracePeriod) var orphanIDs []string - var totalSize int64 + var size int64 + sizes := make(map[string]int64, len(remoteObjects)) for _, obj := range remoteObjects { if !fileIDRegex.MatchString(obj.ID) { continue @@ -138,12 +157,14 @@ func (c *Store) sync() error { if _, ok := localIDMap[obj.ID]; !ok && obj.LastModified.Before(cutoff) { orphanIDs = append(orphanIDs, obj.ID) } else { - totalSize += obj.Size + size += obj.Size + sizes[obj.ID] = obj.Size } } - log.Tag(tagStore).Debug("Sync: cache size updated to %s", util.FormatSizeHuman(totalSize)) + log.Tag(tagStore).Debug("Sync: cache size updated to %s", util.FormatSizeHuman(size)) c.mu.Lock() - c.totalSizeCurrent = totalSize + c.size = size + c.sizes = sizes c.mu.Unlock() // Delete orphaned attachments if len(orphanIDs) > 0 { @@ -152,6 +173,10 @@ func (c *Store) sync() error { return fmt.Errorf("attachment sync: failed to delete orphaned objects: %w", err) } } + // Clean up incomplete uploads (S3 only) + if err := c.backend.DeleteIncomplete(cutoff); err != nil { + log.Tag(tagStore).Err(err).Warn("Sync: failed to abort incomplete uploads") + } return nil } @@ -159,14 +184,14 @@ func (c *Store) sync() error { func (c *Store) Size() int64 { c.mu.Lock() defer c.mu.Unlock() - return c.totalSizeCurrent + return c.size } // Remaining returns the remaining capacity for attachments func (c *Store) Remaining() int64 { c.mu.Lock() defer c.mu.Unlock() - remaining := c.totalSizeLimit - c.totalSizeCurrent + remaining := c.limit - c.size if remaining < 0 { return 0 } diff --git a/attachment/store_file_test.go b/attachment/store_file_test.go index ceac09d7..c65bad92 100644 --- a/attachment/store_file_test.go +++ b/attachment/store_file_test.go @@ -57,8 +57,8 @@ func TestFileStore_Write_Remove_Success(t *testing.T) { require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5")) require.NoFileExists(t, dir+"/abcdefghijk1") require.NoFileExists(t, dir+"/abcdefghijk5") - // Size is not recomputed by Remove; 
it stays stale until next sync - require.Equal(t, int64(9990), c.Size()) + require.Equal(t, int64(8*999), c.Size()) + require.Equal(t, int64(10240-8*999), c.Remaining()) } func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) { diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index e29f9ed6..7ae122ae 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -42,7 +42,7 @@ func TestS3Store_WriteReadRemove(t *testing.T) { // Remove require.Nil(t, cache.Remove("abcdefghijkl")) - // Size is not recomputed by Remove; stays stale until next sync + require.Equal(t, int64(0), cache.Size()) // Read after remove should fail _, _, err = cache.Read("abcdefghijkl") @@ -107,8 +107,7 @@ func TestS3Store_WriteRemoveMultiple(t *testing.T) { require.Equal(t, int64(500), cache.Size()) require.Nil(t, cache.Remove("abcdefghijk1", "abcdefghijk3")) - // Size not recomputed by Remove - require.Equal(t, int64(500), cache.Size()) + require.Equal(t, int64(300), cache.Size()) } func TestS3Store_ReadNotFound(t *testing.T) { diff --git a/cmd/serve.go b/cmd/serve.go index 26a08c81..52794a07 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -52,8 +52,7 @@ var flagsServe = append( altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-users", Aliases: []string{"auth_users"}, EnvVars: []string{"NTFY_AUTH_USERS"}, Usage: "pre-provisioned declarative users"}), altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}), altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}), - altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached 
files"}), - altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-s3-url", Aliases: []string{"attachment_s3_url"}, EnvVars: []string{"NTFY_ATTACHMENT_S3_URL"}, Usage: "S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files, or S3 URL (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-expiry-duration", Aliases: []string{"attachment_expiry_duration", "X"}, EnvVars: []string{"NTFY_ATTACHMENT_EXPIRY_DURATION"}, Value: util.FormatDuration(server.DefaultAttachmentExpiryDuration), Usage: "duration after which uploaded attachments will be deleted (e.g. 
3h, 20h)"}), @@ -167,7 +166,6 @@ func execServe(c *cli.Context) error { authAccessRaw := c.StringSlice("auth-access") authTokensRaw := c.StringSlice("auth-tokens") attachmentCacheDir := c.String("attachment-cache-dir") - attachmentS3URL := c.String("attachment-s3-url") attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit") attachmentFileSizeLimitStr := c.String("attachment-file-size-limit") attachmentExpiryDurationStr := c.String("attachment-expiry-duration") @@ -316,10 +314,6 @@ func execServe(c *cli.Context) error { return errors.New("if smtp-server-listen is set, smtp-server-domain must also be set") } else if attachmentCacheDir != "" && baseURL == "" { return errors.New("if attachment-cache-dir is set, base-url must also be set") - } else if attachmentS3URL != "" && baseURL == "" { - return errors.New("if attachment-s3-url is set, base-url must also be set") - } else if attachmentS3URL != "" && attachmentCacheDir != "" { - return errors.New("attachment-cache-dir and attachment-s3-url are mutually exclusive") } else if baseURL != "" { u, err := url.Parse(baseURL) if err != nil { @@ -463,7 +457,6 @@ func execServe(c *cli.Context) error { conf.AuthAccess = authAccess conf.AuthTokens = authTokens conf.AttachmentCacheDir = attachmentCacheDir - conf.AttachmentS3URL = attachmentS3URL conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit conf.AttachmentFileSizeLimit = attachmentFileSizeLimit conf.AttachmentExpiryDuration = attachmentExpiryDuration diff --git a/docs/config.md b/docs/config.md index 34484a51..edfa43ff 100644 --- a/docs/config.md +++ b/docs/config.md @@ -490,7 +490,7 @@ Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscri ## Attachments If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable this feature, you have to configure an attachment storage backend and a base URL (`base-url`). 
Attachments can be stored -either on the local filesystem (`attachment-cache-dir`) or in an S3-compatible object store (`attachment-s3-url`). +either on the local filesystem or in an S3-compatible object store, both using the `attachment-cache-dir` option. Once configured, you can upload attachments via PUT. By default, attachments are stored **for only 3 hours**. The main reason for this is to avoid legal issues @@ -498,8 +498,7 @@ and such when hosting user controlled content. Typically, this is more than enou feature) to download the file. The following config options are relevant to attachments: * `base-url` is the root URL for the ntfy server; this is needed for the generated attachment URLs -* `attachment-cache-dir` is the cache directory for attached files (mutually exclusive with `attachment-s3-url`) -* `attachment-s3-url` is the S3 URL for attachment storage (mutually exclusive with `attachment-cache-dir`) +* `attachment-cache-dir` is the cache directory for attached files, or an S3 URL for object storage * `attachment-total-size-limit` is the size limit of the attachment storage (default: 5G) * `attachment-file-size-limit` is the per-file attachment size limit (e.g. 300k, 2M, 100M, default: 15M) * `attachment-expiry-duration` is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h, default: 3h) @@ -528,7 +527,7 @@ Here's an example config using the local filesystem for attachment storage: As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3, MinIO, DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage. 
-The `attachment-s3-url` option uses the following format: +To use S3, set `attachment-cache-dir` to an S3 URL with the following format: ``` s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] @@ -539,13 +538,13 @@ When `endpoint` is specified, path-style addressing is enabled automatically (us === "/etc/ntfy/server.yml (AWS S3)" ``` yaml base-url: "https://ntfy.sh" - attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1" + attachment-cache-dir: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1" ``` === "/etc/ntfy/server.yml (MinIO/custom endpoint)" ``` yaml base-url: "https://ntfy.sh" - attachment-s3-url: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com" + attachment-cache-dir: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com" ``` Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit` @@ -2143,8 +2142,7 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`). | `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) | | `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) | | `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header | -| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory* | - | Cache directory for attached files. Mutually exclusive with `attachment-s3-url`. | -| `attachment-s3-url` | `NTFY_ATTACHMENT_S3_URL` | *URL* | - | S3 URL for attachment storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). 
Mutually exclusive with `attachment-cache-dir`. | +| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory or S3 URL* | - | Cache directory for attached files, or S3 URL for object storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). | | `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limits is exceeded, new attachments will be rejected. | | `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachment will be rejected. | | `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. | @@ -2246,8 +2244,7 @@ OPTIONS: --auth-file value, --auth_file value, -H value auth database file used for access control [$NTFY_AUTH_FILE] --auth-startup-queries value, --auth_startup_queries value queries run when the auth database is initialized [$NTFY_AUTH_STARTUP_QUERIES] --auth-default-access value, --auth_default_access value, -p value default permissions if no matching entries in the auth database are found (default: "read-write") [$NTFY_AUTH_DEFAULT_ACCESS] - --attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files [$NTFY_ATTACHMENT_CACHE_DIR] - --attachment-s3-url value, --attachment_s3_url value S3 URL for attachment storage (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_S3_URL] + --attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files, or S3 URL (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_CACHE_DIR] --attachment-total-size-limit value, --attachment_total_size_limit value, -A value limit of the on-disk attachment cache (default: "5G") [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT] 
--attachment-file-size-limit value, --attachment_file_size_limit value, -Y value per-file attachment size limit (e.g. 300k, 2M, 100M) (default: "15M") [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT] --attachment-expiry-duration value, --attachment_expiry_duration value, -X value duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: "3h") [$NTFY_ATTACHMENT_EXPIRY_DURATION] diff --git a/docs/releases.md b/docs/releases.md index dbf9bd41..b16608c6 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1802,7 +1802,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release **Features:** -* Add S3-compatible object storage as an alternative attachment backend via `attachment-s3-url` config option +* Add S3-compatible object storage as an alternative attachment backend via `attachment-cache-dir` config option **Bug fixes + maintenance:** diff --git a/s3/client.go b/s3/client.go index 5ec8caf6..41d940f3 100644 --- a/s3/client.go +++ b/s3/client.go @@ -232,6 +232,77 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) } +// ListMultipartUploads returns in-progress multipart uploads for the client's prefix. +// It paginates automatically, stopping after 10,000 pages as a safety valve. 
+func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) { + var all []MultipartUpload + var keyMarker, uploadIDMarker string + for page := 0; page < maxPages; page++ { + query := url.Values{"uploads": {""}} + if prefix := c.prefixForList(); prefix != "" { + query.Set("prefix", prefix) + } + if keyMarker != "" { + query.Set("key-marker", keyMarker) + query.Set("upload-id-marker", uploadIDMarker) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil) + if err != nil { + return nil, fmt.Errorf("s3: ListMultipartUploads request: %w", err) + } + c.signV4(req, emptyPayloadHash) + resp, err := c.httpClient().Do(req) + if err != nil { + return nil, fmt.Errorf("s3: ListMultipartUploads: %w", err) + } + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + resp.Body.Close() + if err != nil { + return nil, fmt.Errorf("s3: ListMultipartUploads read: %w", err) + } + if !isHTTPSuccess(resp) { + return nil, parseErrorFromBytes(resp.StatusCode, respBody) + } + var result listMultipartUploadsResult + if err := xml.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err) + } + for _, u := range result.Uploads { + var initiated time.Time + if u.Initiated != "" { + initiated, _ = time.Parse(time.RFC3339, u.Initiated) + } + all = append(all, MultipartUpload{ + Key: u.Key, + UploadID: u.UploadID, + Initiated: initiated, + }) + } + if !result.IsTruncated { + return all, nil + } + keyMarker = result.NextKeyMarker + uploadIDMarker = result.NextUploadIDMarker + } + return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages) +} + +// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated +// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads. 
+func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error { + uploads, err := c.ListMultipartUploads(ctx) + if err != nil { + return err + } + for _, u := range uploads { + if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) { + log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated) + c.abortMultipartUpload(ctx, u.Key, u.UploadID) + } + } + return nil +} + // putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error { fullKey := c.objectKey(key) diff --git a/s3/types.go b/s3/types.go index 65615fcd..f78c4a8b 100644 --- a/s3/types.go +++ b/s3/types.go @@ -69,6 +69,27 @@ type deleteError struct { Message string `xml:"Message"` } +// MultipartUpload represents an in-progress multipart upload returned by ListMultipartUploads. +type MultipartUpload struct { + Key string + UploadID string + Initiated time.Time +} + +// listMultipartUploadsResult is the XML response from S3 ListMultipartUploads +type listMultipartUploadsResult struct { + Uploads []listUpload `xml:"Upload"` + IsTruncated bool `xml:"IsTruncated"` + NextKeyMarker string `xml:"NextKeyMarker"` + NextUploadIDMarker string `xml:"NextUploadIdMarker"` +} + +type listUpload struct { + Key string `xml:"Key"` + UploadID string `xml:"UploadId"` + Initiated string `xml:"Initiated"` +} + // initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload type initiateMultipartUploadResult struct { UploadID string `xml:"UploadId"` diff --git a/server/config.go b/server/config.go index 97f72a1c..8ead312c 100644 --- a/server/config.go +++ b/server/config.go @@ -112,7 +112,6 @@ type Config struct { AuthBcryptCost int AuthStatsQueueWriterInterval time.Duration AttachmentCacheDir string - AttachmentS3URL string AttachmentTotalSizeLimit int64 AttachmentFileSizeLimit int64 AttachmentExpiryDuration 
time.Duration diff --git a/server/server.go b/server/server.go index b43b71ef..99a61906 100644 --- a/server/server.go +++ b/server/server.go @@ -301,13 +301,13 @@ func createMessageCache(conf *Config, pool *db.DB) (*message.Cache, error) { } func createAttachmentStore(conf *Config, messageCache *message.Cache) (*attachment.Store, error) { - idProvider := func() ([]string, error) { + attachmentIDs := func() ([]string, error) { return messageCache.AttachmentIDs() } - if conf.AttachmentS3URL != "" { - return attachment.NewS3Store(conf.AttachmentS3URL, conf.AttachmentTotalSizeLimit, idProvider) + if strings.HasPrefix(conf.AttachmentCacheDir, "s3://") { + return attachment.NewS3Store(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, attachmentIDs) } else if conf.AttachmentCacheDir != "" { - return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, idProvider) + return attachment.NewFileStore(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit, attachmentIDs) } return nil, nil } diff --git a/server/server.yml b/server/server.yml index e6f7afee..9dc92968 100644 --- a/server/server.yml +++ b/server/server.yml @@ -153,13 +153,13 @@ # If enabled, clients can attach files to notifications as attachments. Minimum settings to enable attachments # are "attachment-cache-dir" and "base-url". # -# - attachment-cache-dir is the cache directory for attached files +# - attachment-cache-dir is the cache directory for attached files, or an S3 URL for object storage +# e.g. /var/cache/ntfy/attachments, or s3://ACCESS_KEY:SECRET_KEY@bucket/prefix?region=us-east-1&endpoint=https://... # - attachment-total-size-limit is the limit of the on-disk attachment cache directory (total size) # - attachment-file-size-limit is the per-file attachment size limit (e.g. 300k, 2M, 100M) # - attachment-expiry-duration is the duration after which uploaded attachments will be deleted (e.g. 
3h, 20h) # # attachment-cache-dir: -# attachment-s3-url: "s3://ACCESS_KEY:SECRET_KEY@bucket/prefix?region=us-east-1" # attachment-total-size-limit: "5G" # attachment-file-size-limit: "15M" # attachment-expiry-duration: "3h" From d86e20173cc040596f02598ee104e8ab4f498a7a Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 19 Mar 2026 21:46:52 -0400 Subject: [PATCH 09/32] Move stuff around --- attachment/backend_s3.go | 8 +- attachment/store_s3_test.go | 4 +- s3/client.go | 407 +++++++----------------------------- s3/client_auth.go | 68 ++++++ s3/client_multipart.go | 188 +++++++++++++++++ s3/client_test.go | 61 +++--- s3/types.go | 44 ++-- s3/util_test.go | 12 +- 8 files changed, 396 insertions(+), 396 deletions(-) create mode 100644 s3/client_auth.go create mode 100644 s3/client_multipart.go diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index 11d2254b..6603bb91 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -3,7 +3,6 @@ package attachment import ( "context" "io" - "strings" "time" "heckel.io/ntfy/v2/log" @@ -38,15 +37,10 @@ func (b *s3Backend) List() ([]object, error) { if err != nil { return nil, err } - prefix := b.client.Prefix result := make([]object, 0, len(objects)) for _, obj := range objects { - id := obj.Key - if prefix != "" { - id = strings.TrimPrefix(id, prefix+"/") - } result = append(result, object{ - ID: id, + ID: obj.Key, Size: obj.Size, LastModified: obj.LastModified, }) diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 7ae122ae..3ad5a93c 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -197,7 +197,7 @@ func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) { func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store { t.Helper() host := strings.TrimPrefix(server.URL, "https://") - backend := newS3Backend(&s3.Client{ + backend := newS3Backend(s3.New(&s3.Config{ AccessKey: "AKID", SecretKey: 
"SECRET", Region: "us-east-1", @@ -206,7 +206,7 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string Prefix: prefix, PathStyle: true, HTTPClient: server.Client(), - }) + })) cache, err := newStore(backend, totalSizeLimit, nil) require.Nil(t, err) t.Cleanup(func() { cache.Close() }) diff --git a/s3/client.go b/s3/client.go index 41d940f3..5c43af05 100644 --- a/s3/client.go +++ b/s3/client.go @@ -8,14 +8,12 @@ import ( "context" "crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers "encoding/base64" - "encoding/hex" "encoding/xml" "errors" "fmt" "io" "net/http" "net/url" - "sort" "strconv" "strings" "time" @@ -33,26 +31,19 @@ const ( // // Fields must not be modified after the Client is passed to any method or goroutine. type Client struct { - AccessKey string // AWS access key ID - SecretKey string // AWS secret access key - Region string // e.g. "us-east-1" - Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com" - Bucket string // S3 bucket name - Prefix string // optional key prefix (e.g. "attachments"); prepended to all keys automatically - PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style - HTTPClient *http.Client // if nil, http.DefaultClient is used + config *Config + http *http.Client } // New creates a new S3 client from the given Config. 
func New(config *Config) *Client { + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } return &Client{ - AccessKey: config.AccessKey, - SecretKey: config.SecretKey, - Region: config.Region, - Endpoint: config.Endpoint, - Bucket: config.Bucket, - Prefix: config.Prefix, - PathStyle: config.PathStyle, + config: config, + http: httpClient, } } @@ -87,7 +78,7 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int6 return nil, 0, fmt.Errorf("s3: GetObject request: %w", err) } c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) + resp, err := c.http.Do(req) if err != nil { return nil, 0, fmt.Errorf("s3: GetObject: %w", err) } @@ -116,35 +107,18 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { } body.WriteString("") bodyBytes := body.Bytes() - payloadHash := sha256Hex(bodyBytes) // Content-MD5 is required by the S3 protocol for DeleteObjects requests. md5Sum := md5.Sum(bodyBytes) //nolint:gosec contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:]) - reqURL := c.bucketURL() + "?delete=" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes)) + respBody, err := c.doWithBodyAndHeaders(ctx, http.MethodPost, c.config.BucketURL()+"?delete=", bodyBytes, + map[string]string{"Content-MD5": contentMD5}, "DeleteObjects") if err != nil { - return fmt.Errorf("s3: DeleteObjects request: %w", err) - } - req.ContentLength = int64(len(bodyBytes)) - req.Header.Set("Content-Type", "application/xml") - req.Header.Set("Content-MD5", contentMD5) - c.signV4(req, payloadHash) - resp, err := c.httpClient().Do(req) - if err != nil { - return fmt.Errorf("s3: DeleteObjects: %w", err) - } - defer resp.Body.Close() - if !isHTTPSuccess(resp) { - return parseError(resp) + return err } // S3 may return HTTP 200 with per-key errors in the response body - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - if err != nil 
{ - return fmt.Errorf("s3: DeleteObjects read response: %w", err) - } var result deleteResult if err := xml.Unmarshal(respBody, &result); err != nil { return nil // If we can't parse, assume success (Quiet mode returns empty body on success) @@ -159,9 +133,9 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { return nil } -// ListObjects performs a single ListObjectsV2 request using the client's configured prefix. +// listObjects performs a single ListObjectsV2 request using the client's configured prefix. // Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000). -func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) { +func (c *Client) listObjects(ctx context.Context, continuationToken string, maxKeys int) (*listResult, error) { log.Tag(tagS3Client).Debug("ListObjects continuation=%s maxKeys=%d", continuationToken, maxKeys) query := url.Values{"list-type": {"2"}} if prefix := c.prefixForList(); prefix != "" { @@ -173,22 +147,9 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK if maxKeys > 0 { query.Set("max-keys", strconv.Itoa(maxKeys)) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil) + respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListObjects") if err != nil { - return nil, fmt.Errorf("s3: ListObjects request: %w", err) - } - c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) - if err != nil { - return nil, fmt.Errorf("s3: ListObjects: %w", err) - } - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("s3: ListObjects read: %w", err) - } - if !isHTTPSuccess(resp) { - return nil, parseErrorFromBytes(resp.StatusCode, respBody) + return nil, err } var result listObjectsV2Response if err := xml.Unmarshal(respBody, 
&result); err != nil { @@ -206,7 +167,7 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK LastModified: lastModified, } } - return &ListResult{ + return &listResult{ Objects: objects, IsTruncated: result.IsTruncated, NextContinuationToken: result.NextContinuationToken, @@ -214,16 +175,21 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK } // ListAllObjects returns all objects under the client's configured prefix by paginating through -// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve. +// ListObjectsV2 results automatically. Keys in the returned objects have the prefix stripped, +// so they match the keys used with PutObject/GetObject/DeleteObjects. It stops after 10,000 +// pages as a safety valve. func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { var all []Object var token string for page := 0; page < maxPages; page++ { - result, err := c.ListObjects(ctx, token, 0) + result, err := c.listObjects(ctx, token, 0) if err != nil { return nil, err } - all = append(all, result.Objects...) + for _, obj := range result.Objects { + obj.Key = c.stripPrefix(obj.Key) + all = append(all, obj) + } if !result.IsTruncated { return all, nil } @@ -232,77 +198,6 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) } -// ListMultipartUploads returns in-progress multipart uploads for the client's prefix. -// It paginates automatically, stopping after 10,000 pages as a safety valve. 
-func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) { - var all []MultipartUpload - var keyMarker, uploadIDMarker string - for page := 0; page < maxPages; page++ { - query := url.Values{"uploads": {""}} - if prefix := c.prefixForList(); prefix != "" { - query.Set("prefix", prefix) - } - if keyMarker != "" { - query.Set("key-marker", keyMarker) - query.Set("upload-id-marker", uploadIDMarker) - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil) - if err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads request: %w", err) - } - c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) - if err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads: %w", err) - } - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads read: %w", err) - } - if !isHTTPSuccess(resp) { - return nil, parseErrorFromBytes(resp.StatusCode, respBody) - } - var result listMultipartUploadsResult - if err := xml.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err) - } - for _, u := range result.Uploads { - var initiated time.Time - if u.Initiated != "" { - initiated, _ = time.Parse(time.RFC3339, u.Initiated) - } - all = append(all, MultipartUpload{ - Key: u.Key, - UploadID: u.UploadID, - Initiated: initiated, - }) - } - if !result.IsTruncated { - return all, nil - } - keyMarker = result.NextKeyMarker - uploadIDMarker = result.NextUploadIDMarker - } - return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages) -} - -// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated -// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads. 
-func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error { - uploads, err := c.ListMultipartUploads(ctx) - if err != nil { - return err - } - for _, u := range uploads { - if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) { - log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated) - c.abortMultipartUpload(ctx, u.Key, u.UploadID) - } - } - return nil -} - // putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error { fullKey := c.objectKey(key) @@ -312,235 +207,80 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size } req.ContentLength = size c.signV4(req, unsignedPayload) - resp, err := c.httpClient().Do(req) + resp, err := c.http.Do(req) if err != nil { return fmt.Errorf("s3: PutObject: %w", err) } - defer resp.Body.Close() + resp.Body.Close() if !isHTTPSuccess(resp) { return parseError(resp) } return nil } -// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize -// chunks, uploading each as a separate part. This allows uploading without knowing the total -// body size in advance. -func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error { - fullKey := c.objectKey(key) - - // Step 1: Initiate multipart upload - uploadID, err := c.initiateMultipartUpload(ctx, fullKey) +// do creates a request, signs it with an empty payload, executes it, reads the response body, +// and checks for errors. It is used for bodiless GET/POST requests. 
+func (c *Client) do(ctx context.Context, method, reqURL string, body io.Reader, op string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, method, reqURL, body) if err != nil { - return err + return nil, fmt.Errorf("s3: %s request: %w", op, err) } - - // Step 2: Upload parts - var parts []completedPart - buf := make([]byte, partSize) - partNumber := 1 - for { - n, err := io.ReadFull(body, buf) - if n > 0 { - etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n]) - if uploadErr != nil { - c.abortMultipartUpload(ctx, fullKey, uploadID) - return uploadErr - } - parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag}) - partNumber++ - } - if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { - break - } - if err != nil { - c.abortMultipartUpload(ctx, fullKey, uploadID) - return fmt.Errorf("s3: PutObject read: %w", err) - } + if body == nil { + req.ContentLength = 0 } - - // Step 3: Complete multipart upload - return c.completeMultipartUpload(ctx, fullKey, uploadID, parts) -} - -// initiateMultipartUpload starts a new multipart upload and returns the upload ID. 
-func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) { - reqURL := c.objectURL(fullKey) + "?uploads" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil) - if err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err) - } - req.ContentLength = 0 c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) + resp, err := c.http.Do(req) if err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err) - } - defer resp.Body.Close() - if !isHTTPSuccess(resp) { - return "", parseError(resp) + return nil, fmt.Errorf("s3: %s: %w", op, err) } respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + resp.Body.Close() if err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err) + return nil, fmt.Errorf("s3: %s read: %w", op, err) } - var result initiateMultipartUploadResult - if err := xml.Unmarshal(respBody, &result); err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) - } - log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID) - return result.UploadID, nil -} - -// uploadPart uploads a single part of a multipart upload and returns the ETag. 
-func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) { - log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data)) - reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID)) - req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) - if err != nil { - return "", fmt.Errorf("s3: UploadPart request: %w", err) - } - req.ContentLength = int64(len(data)) - c.signV4(req, unsignedPayload) - resp, err := c.httpClient().Do(req) - if err != nil { - return "", fmt.Errorf("s3: UploadPart: %w", err) - } - defer resp.Body.Close() if !isHTTPSuccess(resp) { - return "", parseError(resp) + return nil, parseErrorFromBytes(resp.StatusCode, respBody) } - etag := resp.Header.Get("ETag") - return etag, nil + return respBody, nil } -// completeMultipartUpload finalizes a multipart upload with the given parts. -func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error { - log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts)) - var body bytes.Buffer - body.WriteString("") - for _, p := range parts { - fmt.Fprintf(&body, "%d%s", p.PartNumber, p.ETag) - } - body.WriteString("") - bodyBytes := body.Bytes() - payloadHash := sha256Hex(bodyBytes) +// doWithBody is like do, but sends a body with a computed SHA-256 payload hash and Content-Type: application/xml. 
+func (c *Client) doWithBody(ctx context.Context, method, reqURL string, bodyBytes []byte, op string) ([]byte, error) { + return c.doWithBodyAndHeaders(ctx, method, reqURL, bodyBytes, nil, op) +} - reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes)) +// doWithBodyAndHeaders is like doWithBody, but allows setting additional headers (e.g. Content-MD5). +func (c *Client) doWithBodyAndHeaders(ctx context.Context, method, reqURL string, bodyBytes []byte, headers map[string]string, op string) ([]byte, error) { + payloadHash := sha256Hex(bodyBytes) + req, err := http.NewRequestWithContext(ctx, method, reqURL, bytes.NewReader(bodyBytes)) if err != nil { - return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err) + return nil, fmt.Errorf("s3: %s request: %w", op, err) } req.ContentLength = int64(len(bodyBytes)) req.Header.Set("Content-Type", "application/xml") + for k, v := range headers { + req.Header.Set(k, v) + } c.signV4(req, payloadHash) - resp, err := c.httpClient().Do(req) + resp, err := c.http.Do(req) if err != nil { - return fmt.Errorf("s3: CompleteMultipartUpload: %w", err) + return nil, fmt.Errorf("s3: %s: %w", op, err) } - defer resp.Body.Close() - if !isHTTPSuccess(resp) { - return parseError(resp) - } - // Read response body to check for errors (S3 can return 200 with an error body) respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - if err != nil { - return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err) - } - // Check if the response contains an error - var errResp ErrorResponse - if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" { - errResp.StatusCode = resp.StatusCode - return &errResp - } - return nil -} - -// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. 
-func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) { - log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID) - reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) - if err != nil { - return - } - c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) - if err != nil { - return - } resp.Body.Close() -} - -// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256 -// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. -func (c *Client) signV4(req *http.Request, payloadHash string) { - now := time.Now().UTC() - datestamp := now.Format("20060102") - amzDate := now.Format("20060102T150405Z") - - // Required headers - req.Header.Set("Host", c.hostHeader()) - req.Header.Set("X-Amz-Date", amzDate) - req.Header.Set("X-Amz-Content-Sha256", payloadHash) - - // Canonical headers (all headers we set, sorted by lowercase key) - signedKeys := make([]string, 0, len(req.Header)) - canonHeaders := make(map[string]string, len(req.Header)) - for k := range req.Header { - lk := strings.ToLower(k) - signedKeys = append(signedKeys, lk) - canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k)) + if err != nil { + return nil, fmt.Errorf("s3: %s read: %w", op, err) } - sort.Strings(signedKeys) - signedHeadersStr := strings.Join(signedKeys, ";") - var chBuf strings.Builder - for _, k := range signedKeys { - chBuf.WriteString(k) - chBuf.WriteByte(':') - chBuf.WriteString(canonHeaders[k]) - chBuf.WriteByte('\n') + if !isHTTPSuccess(resp) { + return nil, parseErrorFromBytes(resp.StatusCode, respBody) } - - // Canonical request - canonicalRequest := strings.Join([]string{ - req.Method, - canonicalURI(req.URL), - canonicalQueryString(req.URL.Query()), - chBuf.String(), - signedHeadersStr, - payloadHash, - }, "\n") - - // String to 
sign - credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request" - stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest)) - - // Signing key - signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256( - []byte("AWS4"+c.SecretKey), []byte(datestamp)), - []byte(c.Region)), - []byte("s3")), - []byte("aws4_request")) - - signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) - req.Header.Set("Authorization", fmt.Sprintf( - "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", - c.AccessKey, credentialScope, signedHeadersStr, signature, - )) -} - -func (c *Client) httpClient() *http.Client { - if c.HTTPClient != nil { - return c.HTTPClient - } - return http.DefaultClient + return respBody, nil } // objectKey prepends the configured prefix to the given key. func (c *Client) objectKey(key string) string { - if c.Prefix != "" { - return c.Prefix + "/" + key + if c.config.Prefix != "" { + return c.config.Prefix + "/" + key } return key } @@ -548,18 +288,19 @@ func (c *Client) objectKey(key string) string { // prefixForList returns the prefix to use in ListObjectsV2 requests, // with a trailing slash so that only objects under the prefix directory are returned. func (c *Client) prefixForList() string { - if c.Prefix != "" { - return c.Prefix + "/" + if c.config.Prefix != "" { + return c.config.Prefix + "/" } return "" } -// bucketURL returns the base URL for bucket-level operations. -func (c *Client) bucketURL() string { - if c.PathStyle { - return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket) +// stripPrefix removes the configured prefix from a key returned by ListObjectsV2, +// so keys match what was passed to PutObject/GetObject/DeleteObjects. 
+func (c *Client) stripPrefix(key string) string { + if c.config.Prefix != "" { + return strings.TrimPrefix(key, c.config.Prefix+"/") } - return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint) + return key } // objectURL returns the full URL for an object (key should already include the prefix). @@ -569,13 +310,5 @@ func (c *Client) objectURL(key string) string { for i, seg := range segments { segments[i] = uriEncode(seg) } - return c.bucketURL() + "/" + strings.Join(segments, "/") -} - -// hostHeader returns the value for the Host header. -func (c *Client) hostHeader() string { - if c.PathStyle { - return c.Endpoint - } - return c.Bucket + "." + c.Endpoint + return c.config.BucketURL() + "/" + strings.Join(segments, "/") } diff --git a/s3/client_auth.go b/s3/client_auth.go new file mode 100644 index 00000000..ede971f3 --- /dev/null +++ b/s3/client_auth.go @@ -0,0 +1,68 @@ +package s3 + +import ( + "encoding/hex" + "fmt" + "net/http" + "sort" + "strings" + "time" +) + +// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256 +// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. 
+func (c *Client) signV4(req *http.Request, payloadHash string) { + now := time.Now().UTC() + datestamp := now.Format("20060102") + amzDate := now.Format("20060102T150405Z") + + // Required headers + req.Header.Set("Host", c.config.HostHeader()) + req.Header.Set("X-Amz-Date", amzDate) + req.Header.Set("X-Amz-Content-Sha256", payloadHash) + + // Canonical headers (all headers we set, sorted by lowercase key) + signedKeys := make([]string, 0, len(req.Header)) + canonHeaders := make(map[string]string, len(req.Header)) + for k := range req.Header { + lk := strings.ToLower(k) + signedKeys = append(signedKeys, lk) + canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k)) + } + sort.Strings(signedKeys) + signedHeadersStr := strings.Join(signedKeys, ";") + var chBuf strings.Builder + for _, k := range signedKeys { + chBuf.WriteString(k) + chBuf.WriteByte(':') + chBuf.WriteString(canonHeaders[k]) + chBuf.WriteByte('\n') + } + + // Canonical request + canonicalRequest := strings.Join([]string{ + req.Method, + canonicalURI(req.URL), + canonicalQueryString(req.URL.Query()), + chBuf.String(), + signedHeadersStr, + payloadHash, + }, "\n") + + // String to sign + credentialScope := datestamp + "/" + c.config.Region + "/s3/aws4_request" + stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest)) + + // Signing key + signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256( + []byte("AWS4"+c.config.SecretKey), []byte(datestamp)), + []byte(c.config.Region)), + []byte("s3")), + []byte("aws4_request")) + + signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) + req.Header.Set("Authorization", fmt.Sprintf( + "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", + c.config.AccessKey, credentialScope, signedHeadersStr, signature, + )) +} diff --git a/s3/client_multipart.go b/s3/client_multipart.go new file mode 100644 index 00000000..61d206b6 --- /dev/null +++ b/s3/client_multipart.go @@ 
-0,0 +1,188 @@ +package s3 + +import ( + "bytes" + "context" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "heckel.io/ntfy/v2/log" +) + +// ListMultipartUploads returns in-progress multipart uploads for the client's prefix. +// It paginates automatically, stopping after 10,000 pages as a safety valve. +func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) { + var all []MultipartUpload + var keyMarker, uploadIDMarker string + for page := 0; page < maxPages; page++ { + query := url.Values{"uploads": {""}} + if prefix := c.prefixForList(); prefix != "" { + query.Set("prefix", prefix) + } + if keyMarker != "" { + query.Set("key-marker", keyMarker) + query.Set("upload-id-marker", uploadIDMarker) + } + respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListMultipartUploads") + if err != nil { + return nil, err + } + var result listMultipartUploadsResult + if err := xml.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err) + } + for _, u := range result.Uploads { + var initiated time.Time + if u.Initiated != "" { + initiated, _ = time.Parse(time.RFC3339, u.Initiated) + } + all = append(all, MultipartUpload{ + Key: u.Key, + UploadID: u.UploadID, + Initiated: initiated, + }) + } + if !result.IsTruncated { + return all, nil + } + keyMarker = result.NextKeyMarker + uploadIDMarker = result.NextUploadIDMarker + } + return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages) +} + +// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated +// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads. 
+func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error { + uploads, err := c.ListMultipartUploads(ctx) + if err != nil { + return err + } + for _, u := range uploads { + if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) { + log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated) + c.abortMultipartUpload(ctx, u.Key, u.UploadID) + } + } + return nil +} + +// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize +// chunks, uploading each as a separate part. This allows uploading without knowing the total +// body size in advance. +func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error { + fullKey := c.objectKey(key) + + // Step 1: Initiate multipart upload + uploadID, err := c.initiateMultipartUpload(ctx, fullKey) + if err != nil { + return err + } + + // Step 2: Upload parts + var parts []completedPart + buf := make([]byte, partSize) + partNumber := 1 + for { + n, err := io.ReadFull(body, buf) + if n > 0 { + etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n]) + if uploadErr != nil { + c.abortMultipartUpload(ctx, fullKey, uploadID) + return uploadErr + } + parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag}) + partNumber++ + } + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + if err != nil { + c.abortMultipartUpload(ctx, fullKey, uploadID) + return fmt.Errorf("s3: PutObject read: %w", err) + } + } + + // Step 3: Complete multipart upload + return c.completeMultipartUpload(ctx, fullKey, uploadID, parts) +} + +// initiateMultipartUpload starts a new multipart upload and returns the upload ID. 
+func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) { + respBody, err := c.do(ctx, http.MethodPost, c.objectURL(fullKey)+"?uploads", nil, "InitiateMultipartUpload") + if err != nil { + return "", err + } + var result initiateMultipartUploadResult + if err := xml.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) + } + log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID) + return result.UploadID, nil +} + +// uploadPart uploads a single part of a multipart upload and returns the ETag. +func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) { + log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data)) + reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) + if err != nil { + return "", fmt.Errorf("s3: UploadPart request: %w", err) + } + req.ContentLength = int64(len(data)) + c.signV4(req, unsignedPayload) + resp, err := c.http.Do(req) + if err != nil { + return "", fmt.Errorf("s3: UploadPart: %w", err) + } + defer resp.Body.Close() + if !isHTTPSuccess(resp) { + return "", parseError(resp) + } + etag := resp.Header.Get("ETag") + return etag, nil +} + +// completeMultipartUpload finalizes a multipart upload with the given parts. 
+func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
+	log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts))
+	var body bytes.Buffer
+	body.WriteString("<CompleteMultipartUpload>")
+	for _, p := range parts {
+		fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
+	}
+	body.WriteString("</CompleteMultipartUpload>")
+	respBody, err := c.doWithBody(ctx, http.MethodPost,
+		fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)),
+		body.Bytes(), "CompleteMultipartUpload")
+	if err != nil {
+		return err
+	}
+	// Check if the response contains an error (S3 can return 200 with an error body)
+	var errResp ErrorResponse
+	if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
+		return &errResp
+	}
+	return nil
+}
+
+// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
+func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
+	log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID)
+	reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
+	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
+	if err != nil {
+		return
+	}
+	c.signV4(req, emptyPayloadHash)
+	resp, err := c.http.Do(req)
+	if err != nil {
+		return
+	}
+	resp.Body.Close()
+}
diff --git a/s3/client_test.go b/s3/client_test.go
index 8007601c..447f36d0 100644
--- a/s3/client_test.go
+++ b/s3/client_test.go
@@ -269,7 +269,7 @@ func (m *mockS3Server) objectCount() int {
 func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
 	// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
 	host := strings.TrimPrefix(server.URL, "https://")
-	return &Client{
+	return New(&Config{
 		AccessKey: "AKID",
 		SecretKey: "SECRET",
 		Region:    "us-east-1",
@@ -278,7 +278,7 @@ func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
Prefix: prefix, PathStyle: true, HTTPClient: server.Client(), - } + }) } // --- URL parsing tests --- @@ -363,49 +363,49 @@ func TestParseURL_EmptyBucket(t *testing.T) { // --- Unit tests: URL construction --- -func TestClient_BucketURL_PathStyle(t *testing.T) { - c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} - require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL()) +func TestConfig_BucketURL_PathStyle(t *testing.T) { + c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} + require.Equal(t, "https://s3.example.com/my-bucket", c.BucketURL()) } -func TestClient_BucketURL_VirtualHosted(t *testing.T) { - c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} - require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL()) +func TestConfig_BucketURL_VirtualHosted(t *testing.T) { + c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.BucketURL()) } func TestClient_ObjectURL_PathStyle(t *testing.T) { - c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} + c := &Client{config: &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}} require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj")) } func TestClient_ObjectURL_VirtualHosted(t *testing.T) { - c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + c := &Client{config: &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}} require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj")) } -func TestClient_HostHeader_PathStyle(t *testing.T) { - c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} - require.Equal(t, "s3.example.com", c.hostHeader()) +func 
TestConfig_HostHeader_PathStyle(t *testing.T) { + c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} + require.Equal(t, "s3.example.com", c.HostHeader()) } -func TestClient_HostHeader_VirtualHosted(t *testing.T) { - c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} - require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader()) +func TestConfig_HostHeader_VirtualHosted(t *testing.T) { + c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.HostHeader()) } func TestClient_ObjectKey(t *testing.T) { - c := &Client{Prefix: "attachments"} + c := &Client{config: &Config{Prefix: "attachments"}} require.Equal(t, "attachments/file123", c.objectKey("file123")) - c2 := &Client{Prefix: ""} + c2 := &Client{config: &Config{Prefix: ""}} require.Equal(t, "file123", c2.objectKey("file123")) } func TestClient_PrefixForList(t *testing.T) { - c := &Client{Prefix: "attachments"} + c := &Client{config: &Config{Prefix: "attachments"}} require.Equal(t, "attachments/", c.prefixForList()) - c2 := &Client{Prefix: ""} + c2 := &Client{config: &Config{Prefix: ""}} require.Equal(t, "", c2.prefixForList()) } @@ -512,13 +512,13 @@ func TestClient_ListObjects(t *testing.T) { require.Nil(t, err) // List with prefix client: should only see 3 - result, err := client.ListObjects(ctx, "", 0) + result, err := client.listObjects(ctx, "", 0) require.Nil(t, err) require.Len(t, result.Objects, 3) require.False(t, result.IsTruncated) // List with no-prefix client: should see all 4 - result, err = clientNoPrefix.ListObjects(ctx, "", 0) + result, err = clientNoPrefix.listObjects(ctx, "", 0) require.Nil(t, err) require.Len(t, result.Objects, 4) } @@ -537,20 +537,20 @@ func TestClient_ListObjects_Pagination(t *testing.T) { } // List with max-keys=2 - result, err := client.ListObjects(ctx, "", 2) + result, err := client.listObjects(ctx, 
"", 2) require.Nil(t, err) require.Len(t, result.Objects, 2) require.True(t, result.IsTruncated) require.NotEmpty(t, result.NextContinuationToken) // Get next page - result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2) + result2, err := client.listObjects(ctx, result.NextContinuationToken, 2) require.Nil(t, err) require.Len(t, result2.Objects, 2) require.True(t, result2.IsTruncated) // Get last page - result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2) + result3, err := client.listObjects(ctx, result2.NextContinuationToken, 2) require.Nil(t, err) require.Len(t, result3.Objects, 1) require.False(t, result3.IsTruncated) @@ -744,7 +744,7 @@ func TestClient_RealBucket(t *testing.T) { prefix = "ntfy-s3-test" } - client := &Client{ + client := New(&Config{ AccessKey: accessKey, SecretKey: secretKey, Region: region, @@ -752,7 +752,7 @@ func TestClient_RealBucket(t *testing.T) { Bucket: bucket, Prefix: prefix, PathStyle: pathStyle, - } + }) ctx := context.Background() @@ -762,8 +762,7 @@ func TestClient_RealBucket(t *testing.T) { if len(existing) > 0 { keys := make([]string, len(existing)) for i, obj := range existing { - // Strip the prefix since DeleteObjects will re-add it - keys[i] = strings.TrimPrefix(obj.Key, prefix+"/") + keys[i] = obj.Key } // Batch delete in groups of 1000 for i := 0; i < len(keys); i += 1000 { @@ -807,7 +806,7 @@ func TestClient_RealBucket(t *testing.T) { t.Run("ListObjects", func(t *testing.T) { // Use a sub-prefix client for isolation - listClient := &Client{ + listClient := New(&Config{ AccessKey: accessKey, SecretKey: secretKey, Region: region, @@ -815,7 +814,7 @@ func TestClient_RealBucket(t *testing.T) { Bucket: bucket, Prefix: prefix + "/list-test", PathStyle: pathStyle, - } + }) // Put 10 objects for i := 0; i < 10; i++ { diff --git a/s3/types.go b/s3/types.go index f78c4a8b..5fec7c78 100644 --- a/s3/types.go +++ b/s3/types.go @@ -2,18 +2,36 @@ package s3 import ( "fmt" + "net/http" "time" ) // 
Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string. type Config struct { - Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com" - PathStyle bool - Bucket string - Prefix string - Region string - AccessKey string - SecretKey string + Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com" + PathStyle bool + Bucket string + Prefix string + Region string + AccessKey string + SecretKey string + HTTPClient *http.Client // if nil, http.DefaultClient is used +} + +// bucketURL returns the base URL for bucket-level operations. +func (c *Config) BucketURL() string { + if c.PathStyle { + return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket) + } + return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint) +} + +// hostHeader returns the value for the Host header. +func (c *Config) HostHeader() string { + if c.PathStyle { + return c.Endpoint + } + return c.Bucket + "." + c.Endpoint } // Object represents an S3 object returned by list operations. @@ -23,8 +41,8 @@ type Object struct { LastModified time.Time } -// ListResult holds the response from a ListObjectsV2 call. -type ListResult struct { +// listResult holds the response from a single ListObjectsV2 page. 
+type listResult struct { Objects []Object IsTruncated bool NextContinuationToken string @@ -78,10 +96,10 @@ type MultipartUpload struct { // listMultipartUploadsResult is the XML response from S3 ListMultipartUploads type listMultipartUploadsResult struct { - Uploads []listUpload `xml:"Upload"` - IsTruncated bool `xml:"IsTruncated"` - NextKeyMarker string `xml:"NextKeyMarker"` - NextUploadIDMarker string `xml:"NextUploadIdMarker"` + Uploads []listUpload `xml:"Upload"` + IsTruncated bool `xml:"IsTruncated"` + NextKeyMarker string `xml:"NextKeyMarker"` + NextUploadIDMarker string `xml:"NextUploadIdMarker"` } type listUpload struct { diff --git a/s3/util_test.go b/s3/util_test.go index d30c5664..3f08911d 100644 --- a/s3/util_test.go +++ b/s3/util_test.go @@ -105,13 +105,13 @@ func TestHmacSHA256(t *testing.T) { } func TestSignV4_SetsRequiredHeaders(t *testing.T) { - c := &Client{ + c := &Client{config: &Config{ AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", - } + }} req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil) c.signV4(req, emptyPayloadHash) @@ -131,13 +131,13 @@ func TestSignV4_SetsRequiredHeaders(t *testing.T) { } func TestSignV4_UnsignedPayload(t *testing.T) { - c := &Client{ + c := &Client{config: &Config{ AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", - } + }} req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil) c.signV4(req, unsignedPayload) @@ -146,8 +146,8 @@ func TestSignV4_UnsignedPayload(t *testing.T) { } func TestSignV4_DifferentRegions(t *testing.T) { - c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"} - c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", 
Bucket: "b"} + c1 := &Client{config: &Config{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}} + c2 := &Client{config: &Config{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}} req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil) c1.signV4(req1, emptyPayloadHash) From 1f270b68e0a9e9c1e33750d968eb640db9c9f13a Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Thu, 19 Mar 2026 22:42:38 -0400 Subject: [PATCH 10/32] Simplify a little, manual review --- attachment/store.go | 7 +- s3/client.go | 153 +++++++++++++++++------------------------ s3/client_auth.go | 11 +-- s3/client_multipart.go | 90 ++++++++++++------------ s3/client_test.go | 8 +-- s3/types.go | 26 +++++-- 6 files changed, 141 insertions(+), 154 deletions(-) diff --git a/attachment/store.go b/attachment/store.go index 78b8a7cc..0192b09a 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -131,7 +131,6 @@ func (c *Store) Remove(ids ...string) error { // deletes orphans (not in the valid ID set and older than 1 hour), and recomputes // the total size from the remaining objects. 
func (c *Store) sync() error { - log.Tag(tagStore).Debug("Sync: starting sync loop") localIDs, err := c.localIDs() if err != nil { return fmt.Errorf("attachment sync: failed to get valid IDs: %w", err) @@ -161,21 +160,21 @@ func (c *Store) sync() error { sizes[obj.ID] = obj.Size } } - log.Tag(tagStore).Debug("Sync: cache size updated to %s", util.FormatSizeHuman(size)) + log.Tag(tagStore).Debug("Attachment cache size updated to %s", util.FormatSizeHuman(size)) c.mu.Lock() c.size = size c.sizes = sizes c.mu.Unlock() // Delete orphaned attachments if len(orphanIDs) > 0 { - log.Tag(tagStore).Debug("Sync: deleting %d orphaned attachment(s)", len(orphanIDs)) + log.Tag(tagStore).Debug("Deleting %d orphaned attachment(s)", len(orphanIDs)) if err := c.backend.Delete(orphanIDs...); err != nil { return fmt.Errorf("attachment sync: failed to delete orphaned objects: %w", err) } } // Clean up incomplete uploads (S3 only) if err := c.backend.DeleteIncomplete(cutoff); err != nil { - log.Tag(tagStore).Err(err).Warn("Sync: failed to abort incomplete uploads") + log.Tag(tagStore).Err(err).Warn("Failed to abort incomplete uploads from attachment cache") } return nil } diff --git a/s3/client.go b/s3/client.go index 5c43af05..65316fb2 100644 --- a/s3/client.go +++ b/s3/client.go @@ -1,6 +1,3 @@ -// Package s3 provides a minimal S3-compatible client that works with AWS S3, DigitalOcean Spaces, -// GCP Cloud Storage, MinIO, Backblaze B2, and other S3-compatible providers. It uses raw HTTP -// requests with AWS Signature V4 signing, no AWS SDK dependency required. 
package s3 import ( @@ -57,32 +54,26 @@ func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) erro first := make([]byte, partSize) n, err := io.ReadFull(body, first) if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF { - log.Tag(tagS3Client).Debug("PutObject key=%s size=%d (simple)", key, n) - return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n)) + return c.putObjectSimple(ctx, key, bytes.NewReader(first[:n]), int64(n)) + } else if err != nil { + return fmt.Errorf("error reading object %s from client: %w", key, err) } - if err != nil { - return fmt.Errorf("s3: PutObject read: %w", err) - } - log.Tag(tagS3Client).Debug("PutObject key=%s (multipart)", key) - combined := io.MultiReader(bytes.NewReader(first), body) - return c.putObjectMultipart(ctx, key, combined) + return c.putObjectMultipart(ctx, key, io.MultiReader(bytes.NewReader(first), body)) } // GetObject downloads an object. The key is automatically prefixed with the client's configured // prefix. The caller must close the returned ReadCloser. 
func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) { - log.Tag(tagS3Client).Debug("GetObject key=%s", key) - fullKey := c.objectKey(key) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(fullKey), nil) + log.Tag(tagS3Client).Debug("Fetching object %s from backend", key) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(key), nil) if err != nil { - return nil, 0, fmt.Errorf("s3: GetObject request: %w", err) + return nil, 0, fmt.Errorf("error creating HTTP GET request for %s: %w", key, err) } c.signV4(req, emptyPayloadHash) resp, err := c.http.Do(req) if err != nil { - return nil, 0, fmt.Errorf("s3: GetObject: %w", err) - } - if !isHTTPSuccess(resp) { + return nil, 0, fmt.Errorf("error fetching object %s: %w", key, err) + } else if !isHTTPSuccess(resp) { err := parseError(resp) resp.Body.Close() return nil, 0, err @@ -97,25 +88,27 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int6 // Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present // in the response, they are returned as a combined error. func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { - log.Tag(tagS3Client).Debug("DeleteObjects keys=%d", len(keys)) - var body bytes.Buffer - body.WriteString("true") - for _, key := range keys { - body.WriteString("") - xml.EscapeText(&body, []byte(c.objectKey(key))) - body.WriteString("") + log.Tag(tagS3Client).Debug("Deleting %d object(s)", len(keys)) + req := &deleteRequest{ + Quiet: true, + } + for _, key := range keys { + req.Objects = append(req.Objects, &deleteObject{Key: c.objectKey(key)}) + } + body, err := xml.Marshal(req) + if err != nil { + return fmt.Errorf("error marshalling XML for deleting objects: %w", err) } - body.WriteString("") - bodyBytes := body.Bytes() // Content-MD5 is required by the S3 protocol for DeleteObjects requests. 
- md5Sum := md5.Sum(bodyBytes) //nolint:gosec - contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:]) - - respBody, err := c.doWithBodyAndHeaders(ctx, http.MethodPost, c.config.BucketURL()+"?delete=", bodyBytes, - map[string]string{"Content-MD5": contentMD5}, "DeleteObjects") + md5Sum := md5.Sum(body) //nolint:gosec + headers := map[string]string{ + "Content-MD5": base64.StdEncoding.EncodeToString(md5Sum[:]), + } + reqURL := c.config.BucketURL() + "?delete=" + respBody, err := c.do(ctx, http.MethodPost, reqURL, body, headers, "DeleteObjects") if err != nil { - return err + return fmt.Errorf("error deleting objects: %w", err) } // S3 may return HTTP 200 with per-key errors in the response body @@ -128,7 +121,7 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { for _, e := range result.Errors { msgs = append(msgs, fmt.Sprintf("%s: %s", e.Key, e.Message)) } - return fmt.Errorf("s3: DeleteObjects partial failure: %s", strings.Join(msgs, "; ")) + return fmt.Errorf("error deleting objects, partial failure: %s", strings.Join(msgs, "; ")) } return nil } @@ -147,7 +140,7 @@ func (c *Client) listObjects(ctx context.Context, continuationToken string, maxK if maxKeys > 0 { query.Set("max-keys", strconv.Itoa(maxKeys)) } - respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListObjects") + respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil, "ListObjects") if err != nil { return nil, err } @@ -198,12 +191,12 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) } -// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. 
-func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error { - fullKey := c.objectKey(key) - req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(fullKey), body) +// putObjectSimple uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. +func (c *Client) putObjectSimple(ctx context.Context, key string, body io.Reader, size int64) error { + log.Tag(tagS3Client).Debug("Uploading object %s (%d bytes)", key, size) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(key), body) if err != nil { - return fmt.Errorf("s3: PutObject request: %w", err) + return fmt.Errorf("uploading object %s failed: %w", key, err) } req.ContentLength = size c.signV4(req, unsignedPayload) @@ -218,50 +211,32 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size return nil } -// do creates a request, signs it with an empty payload, executes it, reads the response body, -// and checks for errors. It is used for bodiless GET/POST requests. -func (c *Client) do(ctx context.Context, method, reqURL string, body io.Reader, op string) ([]byte, error) { - req, err := http.NewRequestWithContext(ctx, method, reqURL, body) +// do creates a signed request, executes it, reads the response body, and checks for errors. +// If body is nil, the request is sent with an empty payload. If body is non-nil, it is sent +// with a computed SHA-256 payload hash and Content-Type: application/xml. 
+func (c *Client) do(ctx context.Context, method, reqURL string, body []byte, headers map[string]string, op string) ([]byte, error) { + var reader io.Reader + var hash string + if body != nil { + reader = bytes.NewReader(body) + hash = sha256Hex(body) + } else { + hash = emptyPayloadHash + } + req, err := http.NewRequestWithContext(ctx, method, reqURL, reader) if err != nil { return nil, fmt.Errorf("s3: %s request: %w", op, err) } - if body == nil { + if body != nil { + req.ContentLength = int64(len(body)) + req.Header.Set("Content-Type", "application/xml") + } else { req.ContentLength = 0 } - c.signV4(req, emptyPayloadHash) - resp, err := c.http.Do(req) - if err != nil { - return nil, fmt.Errorf("s3: %s: %w", op, err) - } - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("s3: %s read: %w", op, err) - } - if !isHTTPSuccess(resp) { - return nil, parseErrorFromBytes(resp.StatusCode, respBody) - } - return respBody, nil -} - -// doWithBody is like do, but sends a body with a computed SHA-256 payload hash and Content-Type: application/xml. -func (c *Client) doWithBody(ctx context.Context, method, reqURL string, bodyBytes []byte, op string) ([]byte, error) { - return c.doWithBodyAndHeaders(ctx, method, reqURL, bodyBytes, nil, op) -} - -// doWithBodyAndHeaders is like doWithBody, but allows setting additional headers (e.g. Content-MD5). 
-func (c *Client) doWithBodyAndHeaders(ctx context.Context, method, reqURL string, bodyBytes []byte, headers map[string]string, op string) ([]byte, error) { - payloadHash := sha256Hex(bodyBytes) - req, err := http.NewRequestWithContext(ctx, method, reqURL, bytes.NewReader(bodyBytes)) - if err != nil { - return nil, fmt.Errorf("s3: %s request: %w", op, err) - } - req.ContentLength = int64(len(bodyBytes)) - req.Header.Set("Content-Type", "application/xml") for k, v := range headers { req.Header.Set(k, v) } - c.signV4(req, payloadHash) + c.signV4(req, hash) resp, err := c.http.Do(req) if err != nil { return nil, fmt.Errorf("s3: %s: %w", op, err) @@ -277,14 +252,6 @@ func (c *Client) doWithBodyAndHeaders(ctx context.Context, method, reqURL string return respBody, nil } -// objectKey prepends the configured prefix to the given key. -func (c *Client) objectKey(key string) string { - if c.config.Prefix != "" { - return c.config.Prefix + "/" + key - } - return key -} - // prefixForList returns the prefix to use in ListObjectsV2 requests, // with a trailing slash so that only objects under the prefix directory are returned. func (c *Client) prefixForList() string { @@ -303,12 +270,16 @@ func (c *Client) stripPrefix(key string) string { return key } -// objectURL returns the full URL for an object (key should already include the prefix). -// Each path segment is URI-encoded to handle special characters in keys. -func (c *Client) objectURL(key string) string { - segments := strings.Split(key, "/") - for i, seg := range segments { - segments[i] = uriEncode(seg) +// objectKey prepends the configured prefix to the given key. +func (c *Client) objectKey(key string) string { + if c.config.Prefix != "" { + return c.config.Prefix + "/" + key } - return c.config.BucketURL() + "/" + strings.Join(segments, "/") + return key +} + +// objectURL returns the full URL for an object, automatically prepending the configured prefix. 
+func (c *Client) objectURL(key string) string { + u, _ := url.JoinPath(c.config.BucketURL(), c.objectKey(key)) + return u } diff --git a/s3/client_auth.go b/s3/client_auth.go index ede971f3..8b7e6053 100644 --- a/s3/client_auth.go +++ b/s3/client_auth.go @@ -11,7 +11,7 @@ import ( // signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256 // of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. -func (c *Client) signV4(req *http.Request, payloadHash string) { +func (c *Client) signV4(req *http.Request, hash string) { now := time.Now().UTC() datestamp := now.Format("20060102") amzDate := now.Format("20060102T150405Z") @@ -19,7 +19,7 @@ func (c *Client) signV4(req *http.Request, payloadHash string) { // Required headers req.Header.Set("Host", c.config.HostHeader()) req.Header.Set("X-Amz-Date", amzDate) - req.Header.Set("X-Amz-Content-Sha256", payloadHash) + req.Header.Set("X-Amz-Content-Sha256", hash) // Canonical headers (all headers we set, sorted by lowercase key) signedKeys := make([]string, 0, len(req.Header)) @@ -46,7 +46,7 @@ func (c *Client) signV4(req *http.Request, payloadHash string) { canonicalQueryString(req.URL.Query()), chBuf.String(), signedHeadersStr, - payloadHash, + hash, }, "\n") // String to sign @@ -61,8 +61,9 @@ func (c *Client) signV4(req *http.Request, payloadHash string) { []byte("aws4_request")) signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) - req.Header.Set("Authorization", fmt.Sprintf( + header := fmt.Sprintf( "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", c.config.AccessKey, credentialScope, signedHeadersStr, signature, - )) + ) + req.Header.Set("Authorization", header) } diff --git a/s3/client_multipart.go b/s3/client_multipart.go index 61d206b6..d58a9337 100644 --- a/s3/client_multipart.go +++ b/s3/client_multipart.go @@ -14,9 +14,25 @@ import ( "heckel.io/ntfy/v2/log" ) -// ListMultipartUploads returns in-progress 
multipart uploads for the client's prefix. +// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated +// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads. +func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error { + uploads, err := c.listMultipartUploads(ctx) + if err != nil { + return err + } + for _, u := range uploads { + if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) { + log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated) + c.abortMultipartUpload(ctx, u.Key, u.UploadID) + } + } + return nil +} + +// listMultipartUploads returns in-progress multipart uploads for the client's prefix. // It paginates automatically, stopping after 10,000 pages as a safety valve. -func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) { +func (c *Client) listMultipartUploads(ctx context.Context) ([]MultipartUpload, error) { var all []MultipartUpload var keyMarker, uploadIDMarker string for page := 0; page < maxPages; page++ { @@ -28,13 +44,13 @@ func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, e query.Set("key-marker", keyMarker) query.Set("upload-id-marker", uploadIDMarker) } - respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListMultipartUploads") + respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil, "listMultipartUploads") if err != nil { return nil, err } var result listMultipartUploadsResult if err := xml.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err) + return nil, fmt.Errorf("s3: listMultipartUploads XML: %w", err) } for _, u := range result.Uploads { var initiated time.Time @@ -53,33 +69,17 @@ func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, e keyMarker = 
result.NextKeyMarker uploadIDMarker = result.NextUploadIDMarker } - return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages) -} - -// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated -// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads. -func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error { - uploads, err := c.ListMultipartUploads(ctx) - if err != nil { - return err - } - for _, u := range uploads { - if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) { - log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated) - c.abortMultipartUpload(ctx, u.Key, u.UploadID) - } - } - return nil + return nil, fmt.Errorf("s3: listMultipartUploads exceeded %d pages", maxPages) } // putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize // chunks, uploading each as a separate part. This allows uploading without knowing the total // body size in advance. 
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error { - fullKey := c.objectKey(key) + log.Tag(tagS3Client).Debug("Uploading multipart object %s", key) // Step 1: Initiate multipart upload - uploadID, err := c.initiateMultipartUpload(ctx, fullKey) + uploadID, err := c.initiateMultipartUpload(ctx, key) if err != nil { return err } @@ -91,9 +91,9 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea for { n, err := io.ReadFull(body, buf) if n > 0 { - etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n]) + etag, uploadErr := c.uploadPart(ctx, key, uploadID, partNumber, buf[:n]) if uploadErr != nil { - c.abortMultipartUpload(ctx, fullKey, uploadID) + c.abortMultipartUpload(ctx, key, uploadID) return uploadErr } parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag}) @@ -103,18 +103,18 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea break } if err != nil { - c.abortMultipartUpload(ctx, fullKey, uploadID) + c.abortMultipartUpload(ctx, key, uploadID) return fmt.Errorf("s3: PutObject read: %w", err) } } // Step 3: Complete multipart upload - return c.completeMultipartUpload(ctx, fullKey, uploadID, parts) + return c.completeMultipartUpload(ctx, key, uploadID, parts) } // initiateMultipartUpload starts a new multipart upload and returns the upload ID. 
-func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) { - respBody, err := c.do(ctx, http.MethodPost, c.objectURL(fullKey)+"?uploads", nil, "InitiateMultipartUpload") +func (c *Client) initiateMultipartUpload(ctx context.Context, key string) (string, error) { + respBody, err := c.do(ctx, http.MethodPost, c.objectURL(key)+"?uploads", nil, nil, "InitiateMultipartUpload") if err != nil { return "", err } @@ -122,14 +122,14 @@ func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (s if err := xml.Unmarshal(respBody, &result); err != nil { return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) } - log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID) + log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", key, result.UploadID) return result.UploadID, nil } // uploadPart uploads a single part of a multipart upload and returns the ETag. -func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) { - log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data)) - reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID)) +func (c *Client) uploadPart(ctx context.Context, key, uploadID string, partNumber int, data []byte) (string, error) { + log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", key, partNumber, len(data)) + reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(key), partNumber, url.QueryEscape(uploadID)) req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) if err != nil { return "", fmt.Errorf("s3: UploadPart request: %w", err) @@ -149,17 +149,15 @@ func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partN } // completeMultipartUpload finalizes a multipart upload with the given parts. 
-func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error { - log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts)) - var body bytes.Buffer - body.WriteString("") - for _, p := range parts { - fmt.Fprintf(&body, "%d%s", p.PartNumber, p.ETag) +func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID string, parts []completedPart) error { + log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", key, uploadID, len(parts)) + bodyBytes, err := xml.Marshal(completeMultipartUploadRequest{Parts: parts}) + if err != nil { + return fmt.Errorf("s3: CompleteMultipartUpload marshal: %w", err) } - body.WriteString("") - respBody, err := c.doWithBody(ctx, http.MethodPost, - fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)), - body.Bytes(), "CompleteMultipartUpload") + respBody, err := c.do(ctx, http.MethodPost, + fmt.Sprintf("%s?uploadId=%s", c.objectURL(key), url.QueryEscape(uploadID)), + bodyBytes, nil, "CompleteMultipartUpload") if err != nil { return err } @@ -172,9 +170,9 @@ func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID } // abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. 
-func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) { - log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID) - reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) +func (c *Client) abortMultipartUpload(ctx context.Context, key, uploadID string) { + log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", key, uploadID) + reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(key), url.QueryEscape(uploadID)) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) if err != nil { return diff --git a/s3/client_test.go b/s3/client_test.go index 447f36d0..2568bf28 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -374,13 +374,13 @@ func TestConfig_BucketURL_VirtualHosted(t *testing.T) { } func TestClient_ObjectURL_PathStyle(t *testing.T) { - c := &Client{config: &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}} - require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj")) + c := &Client{config: &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", Prefix: "prefix", PathStyle: true}} + require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("obj")) } func TestClient_ObjectURL_VirtualHosted(t *testing.T) { - c := &Client{config: &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}} - require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj")) + c := &Client{config: &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", Prefix: "prefix", PathStyle: false}} + require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("obj")) } func TestConfig_HostHeader_PathStyle(t *testing.T) { diff --git a/s3/types.go b/s3/types.go index 5fec7c78..23ccb15b 100644 --- a/s3/types.go +++ b/s3/types.go @@ -1,6 +1,7 @@ package s3 import ( + "encoding/xml" 
"fmt" "net/http" "time" @@ -76,6 +77,17 @@ type listObject struct { LastModified string `xml:"LastModified"` } +// deleteRequest is the XML request body for S3 DeleteObjects +type deleteRequest struct { + XMLName xml.Name `xml:"Delete"` + Quiet bool `xml:"Quiet"` + Objects []*deleteObject `xml:"Object"` +} + +type deleteObject struct { + Key string `xml:"Key"` +} + // deleteResult is the XML response from S3 DeleteObjects type deleteResult struct { Errors []deleteError `xml:"Error"` @@ -87,14 +99,14 @@ type deleteError struct { Message string `xml:"Message"` } -// MultipartUpload represents an in-progress multipart upload returned by ListMultipartUploads. +// MultipartUpload represents an in-progress multipart upload returned by listMultipartUploads. type MultipartUpload struct { Key string UploadID string Initiated time.Time } -// listMultipartUploadsResult is the XML response from S3 ListMultipartUploads +// listMultipartUploadsResult is the XML response from S3 listMultipartUploads type listMultipartUploadsResult struct { Uploads []listUpload `xml:"Upload"` IsTruncated bool `xml:"IsTruncated"` @@ -113,8 +125,14 @@ type initiateMultipartUploadResult struct { UploadID string `xml:"UploadId"` } +// completeMultipartUploadRequest is the XML request body for S3 CompleteMultipartUpload +type completeMultipartUploadRequest struct { + XMLName xml.Name `xml:"CompleteMultipartUpload"` + Parts []completedPart `xml:"Part"` +} + // completedPart represents a successfully uploaded part for CompleteMultipartUpload type completedPart struct { - PartNumber int - ETag string + PartNumber int `xml:"PartNumber"` + ETag string `xml:"ETag"` } From 02ea09ab0ff572ea759d090f7a7c33be7c84f7b5 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 15:52:45 -0400 Subject: [PATCH 11/32] Refine, manual review, re-org --- attachment/backend_s3.go | 2 +- s3/client.go | 231 +++++++++++++++++---------------------- s3/client_auth.go | 2 + s3/client_multipart.go | 24 ++-- 
s3/client_test.go | 52 ++++----- s3/types.go | 78 +++++++++---- s3/util.go | 8 +- tools/s3cli/main.go | 2 +- 8 files changed, 206 insertions(+), 193 deletions(-) diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index 6603bb91..eb911edc 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -33,7 +33,7 @@ func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) { } func (b *s3Backend) List() ([]object, error) { - objects, err := b.client.ListAllObjects(context.Background()) + objects, err := b.client.ListObjectsV2(context.Background()) if err != nil { return nil, err } diff --git a/s3/client.go b/s3/client.go index 65316fb2..754a1bfb 100644 --- a/s3/client.go +++ b/s3/client.go @@ -47,9 +47,10 @@ func New(config *Config) *Client { // PutObject uploads body to the given key. The key is automatically prefixed with the client's // configured prefix. The body size does not need to be known in advance. // -// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request. -// Otherwise, the body is uploaded using S3 multipart upload, reading one part at a time -// into memory. +// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request +// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html). Otherwise, the body +// is uploaded using S3 multipart upload, reading one part at a time into memory +// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html). func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error { first := make([]byte, partSize) n, err := io.ReadFull(body, first) @@ -61,11 +62,33 @@ func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) erro return c.putObjectMultipart(ctx, key, io.MultiReader(bytes.NewReader(first), body)) } +// putObjectSimple uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. 
+func (c *Client) putObjectSimple(ctx context.Context, key string, body io.Reader, size int64) error { + log.Tag(tagS3Client).Debug("Uploading object %s (%d bytes)", key, size) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.config.ObjectURL(key), body) + if err != nil { + return fmt.Errorf("creating upload request object %s failed: %w", key, err) + } + req.ContentLength = size + c.signV4(req, unsignedPayload) + resp, err := c.http.Do(req) + if err != nil { + return fmt.Errorf("uploading object %s failed: %w", key, err) + } + resp.Body.Close() + if !isHTTPSuccess(resp) { + return parseError(resp) + } + return nil +} + // GetObject downloads an object. The key is automatically prefixed with the client's configured // prefix. The caller must close the returned ReadCloser. +// +// See https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) { - log.Tag(tagS3Client).Debug("Fetching object %s from backend", key) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.objectURL(key), nil) + log.Tag(tagS3Client).Debug("Fetching object %s", key) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.config.ObjectURL(key), nil) if err != nil { return nil, 0, fmt.Errorf("error creating HTTP GET request for %s: %w", key, err) } @@ -81,19 +104,88 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int6 return resp.Body, resp.ContentLength, nil } +// ListObjectsV2 returns all objects under the client's configured prefix by paginating through +// ListObjectsV2 results automatically. Keys in the returned objects have the prefix stripped, +// so they match the keys used with PutObject/GetObject/DeleteObjects. It stops after 10,000 +// pages as a safety valve. 
+// +// See https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html +func (c *Client) ListObjectsV2(ctx context.Context) ([]*Object, error) { + var all []*Object + var token string + for page := 0; page < maxPages; page++ { + result, err := c.listObjectsV2(ctx, token, 0) + if err != nil { + return nil, err + } + for _, obj := range result.Objects { + obj.Key = c.config.StripPrefix(obj.Key) + all = append(all, obj) + } + if !result.IsTruncated { + return all, nil + } + token = result.NextContinuationToken + } + return nil, fmt.Errorf("listing objects exceeded %d pages", maxPages) +} + +// listObjectsV2 performs a single ListObjectsV2 request using the client's configured prefix. +// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000). +func (c *Client) listObjectsV2(ctx context.Context, continuationToken string, maxKeys int) (*listObjectsV2Result, error) { + log.Tag(tagS3Client).Debug("Listing remote objects with continuation token '%s'", continuationToken) + query := url.Values{"list-type": {"2"}} + if prefix := c.config.ListPrefix(); prefix != "" { + query.Set("prefix", prefix) + } + if continuationToken != "" { + query.Set("continuation-token", continuationToken) + } + if maxKeys > 0 { + query.Set("max-keys", strconv.Itoa(maxKeys)) + } + respBody, err := c.do(ctx, "ListObjects", http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil) + if err != nil { + return nil, err + } + var result listObjectsV2Response + if err := xml.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal list object response: %w", err) + } + objects := make([]*Object, len(result.Contents)) + for i, obj := range result.Contents { + var lastModified time.Time + if obj.LastModified != "" { + lastModified, _ = time.Parse(time.RFC3339, obj.LastModified) + } + objects[i] = &Object{ + Key: obj.Key, + Size: obj.Size, + LastModified: lastModified, + } + } + return &listObjectsV2Result{ + Objects: 
objects, + IsTruncated: result.IsTruncated, + NextContinuationToken: result.NextContinuationToken, + }, nil +} + // DeleteObjects removes multiple objects in a single batch request. Keys are automatically // prefixed with the client's configured prefix. S3 supports up to 1000 keys per call; the // caller is responsible for batching if needed. // // Even when S3 returns HTTP 200, individual keys may fail. If any per-key errors are present // in the response, they are returned as a combined error. +// +// See https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { log.Tag(tagS3Client).Debug("Deleting %d object(s)", len(keys)) - req := &deleteRequest{ + req := &deleteObjectsRequest{ Quiet: true, } for _, key := range keys { - req.Objects = append(req.Objects, &deleteObject{Key: c.objectKey(key)}) + req.Objects = append(req.Objects, &deleteObject{Key: c.config.ObjectKey(key)}) } body, err := xml.Marshal(req) if err != nil { @@ -105,14 +197,14 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { headers := map[string]string{ "Content-MD5": base64.StdEncoding.EncodeToString(md5Sum[:]), } - reqURL := c.config.BucketURL() + "?delete=" - respBody, err := c.do(ctx, http.MethodPost, reqURL, body, headers, "DeleteObjects") + reqURL := c.config.BucketURL() + "?delete" + respBody, err := c.do(ctx, "DeleteObjects", http.MethodPost, reqURL, body, headers) if err != nil { return fmt.Errorf("error deleting objects: %w", err) } // S3 may return HTTP 200 with per-key errors in the response body - var result deleteResult + var result deleteObjectsResult if err := xml.Unmarshal(respBody, &result); err != nil { return nil // If we can't parse, assume success (Quiet mode returns empty body on success) } @@ -126,95 +218,10 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { return nil } -// listObjects performs a single ListObjectsV2 request using the 
client's configured prefix. -// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000). -func (c *Client) listObjects(ctx context.Context, continuationToken string, maxKeys int) (*listResult, error) { - log.Tag(tagS3Client).Debug("ListObjects continuation=%s maxKeys=%d", continuationToken, maxKeys) - query := url.Values{"list-type": {"2"}} - if prefix := c.prefixForList(); prefix != "" { - query.Set("prefix", prefix) - } - if continuationToken != "" { - query.Set("continuation-token", continuationToken) - } - if maxKeys > 0 { - query.Set("max-keys", strconv.Itoa(maxKeys)) - } - respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil, "ListObjects") - if err != nil { - return nil, err - } - var result listObjectsV2Response - if err := xml.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("s3: ListObjects XML: %w", err) - } - objects := make([]Object, len(result.Contents)) - for i, obj := range result.Contents { - var lastModified time.Time - if obj.LastModified != "" { - lastModified, _ = time.Parse(time.RFC3339, obj.LastModified) - } - objects[i] = Object{ - Key: obj.Key, - Size: obj.Size, - LastModified: lastModified, - } - } - return &listResult{ - Objects: objects, - IsTruncated: result.IsTruncated, - NextContinuationToken: result.NextContinuationToken, - }, nil -} - -// ListAllObjects returns all objects under the client's configured prefix by paginating through -// ListObjectsV2 results automatically. Keys in the returned objects have the prefix stripped, -// so they match the keys used with PutObject/GetObject/DeleteObjects. It stops after 10,000 -// pages as a safety valve. 
-func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { - var all []Object - var token string - for page := 0; page < maxPages; page++ { - result, err := c.listObjects(ctx, token, 0) - if err != nil { - return nil, err - } - for _, obj := range result.Objects { - obj.Key = c.stripPrefix(obj.Key) - all = append(all, obj) - } - if !result.IsTruncated { - return all, nil - } - token = result.NextContinuationToken - } - return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) -} - -// putObjectSimple uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. -func (c *Client) putObjectSimple(ctx context.Context, key string, body io.Reader, size int64) error { - log.Tag(tagS3Client).Debug("Uploading object %s (%d bytes)", key, size) - req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.objectURL(key), body) - if err != nil { - return fmt.Errorf("uploading object %s failed: %w", key, err) - } - req.ContentLength = size - c.signV4(req, unsignedPayload) - resp, err := c.http.Do(req) - if err != nil { - return fmt.Errorf("s3: PutObject: %w", err) - } - resp.Body.Close() - if !isHTTPSuccess(resp) { - return parseError(resp) - } - return nil -} - // do creates a signed request, executes it, reads the response body, and checks for errors. // If body is nil, the request is sent with an empty payload. If body is non-nil, it is sent // with a computed SHA-256 payload hash and Content-Type: application/xml. 
-func (c *Client) do(ctx context.Context, method, reqURL string, body []byte, headers map[string]string, op string) ([]byte, error) { +func (c *Client) do(ctx context.Context, op, method, reqURL string, body []byte, headers map[string]string) ([]byte, error) { var reader io.Reader var hash string if body != nil { @@ -251,35 +258,3 @@ func (c *Client) do(ctx context.Context, method, reqURL string, body []byte, hea } return respBody, nil } - -// prefixForList returns the prefix to use in ListObjectsV2 requests, -// with a trailing slash so that only objects under the prefix directory are returned. -func (c *Client) prefixForList() string { - if c.config.Prefix != "" { - return c.config.Prefix + "/" - } - return "" -} - -// stripPrefix removes the configured prefix from a key returned by ListObjectsV2, -// so keys match what was passed to PutObject/GetObject/DeleteObjects. -func (c *Client) stripPrefix(key string) string { - if c.config.Prefix != "" { - return strings.TrimPrefix(key, c.config.Prefix+"/") - } - return key -} - -// objectKey prepends the configured prefix to the given key. -func (c *Client) objectKey(key string) string { - if c.config.Prefix != "" { - return c.config.Prefix + "/" + key - } - return key -} - -// objectURL returns the full URL for an object, automatically prepending the configured prefix. -func (c *Client) objectURL(key string) string { - u, _ := url.JoinPath(c.config.BucketURL(), c.objectKey(key)) - return u -} diff --git a/s3/client_auth.go b/s3/client_auth.go index 8b7e6053..61aba73c 100644 --- a/s3/client_auth.go +++ b/s3/client_auth.go @@ -11,6 +11,8 @@ import ( // signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256 // of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. 
+// +// See https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-authenticating-requests.html func (c *Client) signV4(req *http.Request, hash string) { now := time.Now().UTC() datestamp := now.Format("20060102") diff --git a/s3/client_multipart.go b/s3/client_multipart.go index d58a9337..f6b68784 100644 --- a/s3/client_multipart.go +++ b/s3/client_multipart.go @@ -16,6 +16,9 @@ import ( // AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated // before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads. +// +// See https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html +// and https://docs.aws.amazon.com/AmazonS3/latest/API/API_AbortMultipartUpload.html func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error { uploads, err := c.listMultipartUploads(ctx) if err != nil { @@ -32,19 +35,19 @@ func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) e // listMultipartUploads returns in-progress multipart uploads for the client's prefix. // It paginates automatically, stopping after 10,000 pages as a safety valve. 
-func (c *Client) listMultipartUploads(ctx context.Context) ([]MultipartUpload, error) { - var all []MultipartUpload +func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload, error) { + var all []*multipartUpload var keyMarker, uploadIDMarker string for page := 0; page < maxPages; page++ { query := url.Values{"uploads": {""}} - if prefix := c.prefixForList(); prefix != "" { + if prefix := c.config.ListPrefix(); prefix != "" { query.Set("prefix", prefix) } if keyMarker != "" { query.Set("key-marker", keyMarker) query.Set("upload-id-marker", uploadIDMarker) } - respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil, "listMultipartUploads") + respBody, err := c.do(ctx, "listMultipartUploads", http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil) if err != nil { return nil, err } @@ -57,7 +60,7 @@ func (c *Client) listMultipartUploads(ctx context.Context) ([]MultipartUpload, e if u.Initiated != "" { initiated, _ = time.Parse(time.RFC3339, u.Initiated) } - all = append(all, MultipartUpload{ + all = append(all, &multipartUpload{ Key: u.Key, UploadID: u.UploadID, Initiated: initiated, @@ -114,7 +117,7 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea // initiateMultipartUpload starts a new multipart upload and returns the upload ID. func (c *Client) initiateMultipartUpload(ctx context.Context, key string) (string, error) { - respBody, err := c.do(ctx, http.MethodPost, c.objectURL(key)+"?uploads", nil, nil, "InitiateMultipartUpload") + respBody, err := c.do(ctx, "InitiateMultipartUpload", http.MethodPost, c.config.ObjectURL(key)+"?uploads", nil, nil) if err != nil { return "", err } @@ -129,7 +132,7 @@ func (c *Client) initiateMultipartUpload(ctx context.Context, key string) (strin // uploadPart uploads a single part of a multipart upload and returns the ETag. 
func (c *Client) uploadPart(ctx context.Context, key, uploadID string, partNumber int, data []byte) (string, error) { log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", key, partNumber, len(data)) - reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(key), partNumber, url.QueryEscape(uploadID)) + reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.config.ObjectURL(key), partNumber, url.QueryEscape(uploadID)) req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) if err != nil { return "", fmt.Errorf("s3: UploadPart request: %w", err) @@ -155,9 +158,8 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri if err != nil { return fmt.Errorf("s3: CompleteMultipartUpload marshal: %w", err) } - respBody, err := c.do(ctx, http.MethodPost, - fmt.Sprintf("%s?uploadId=%s", c.objectURL(key), url.QueryEscape(uploadID)), - bodyBytes, nil, "CompleteMultipartUpload") + reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID)) + respBody, err := c.do(ctx, "CompleteMultipartUpload", http.MethodPost, reqURL, bodyBytes, nil) if err != nil { return err } @@ -172,7 +174,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri // abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. 
func (c *Client) abortMultipartUpload(ctx context.Context, key, uploadID string) { log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", key, uploadID) - reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(key), url.QueryEscape(uploadID)) + reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID)) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) if err != nil { return diff --git a/s3/client_test.go b/s3/client_test.go index 2568bf28..d488c832 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -373,14 +373,14 @@ func TestConfig_BucketURL_VirtualHosted(t *testing.T) { require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.BucketURL()) } -func TestClient_ObjectURL_PathStyle(t *testing.T) { - c := &Client{config: &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", Prefix: "prefix", PathStyle: true}} - require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("obj")) +func TestConfig_ObjectURL_PathStyle(t *testing.T) { + c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", Prefix: "prefix", PathStyle: true} + require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.ObjectURL("obj")) } -func TestClient_ObjectURL_VirtualHosted(t *testing.T) { - c := &Client{config: &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", Prefix: "prefix", PathStyle: false}} - require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("obj")) +func TestConfig_ObjectURL_VirtualHosted(t *testing.T) { + c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", Prefix: "prefix", PathStyle: false} + require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.ObjectURL("obj")) } func TestConfig_HostHeader_PathStyle(t *testing.T) { @@ -393,20 +393,20 @@ func TestConfig_HostHeader_VirtualHosted(t *testing.T) { require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.HostHeader()) } -func 
TestClient_ObjectKey(t *testing.T) { - c := &Client{config: &Config{Prefix: "attachments"}} - require.Equal(t, "attachments/file123", c.objectKey("file123")) +func TestConfig_ObjectKey(t *testing.T) { + c := &Config{Prefix: "attachments"} + require.Equal(t, "attachments/file123", c.ObjectKey("file123")) - c2 := &Client{config: &Config{Prefix: ""}} - require.Equal(t, "file123", c2.objectKey("file123")) + c2 := &Config{Prefix: ""} + require.Equal(t, "file123", c2.ObjectKey("file123")) } -func TestClient_PrefixForList(t *testing.T) { - c := &Client{config: &Config{Prefix: "attachments"}} - require.Equal(t, "attachments/", c.prefixForList()) +func TestConfig_ListPrefix(t *testing.T) { + c := &Config{Prefix: "attachments"} + require.Equal(t, "attachments/", c.ListPrefix()) - c2 := &Client{config: &Config{Prefix: ""}} - require.Equal(t, "", c2.prefixForList()) + c2 := &Config{Prefix: ""} + require.Equal(t, "", c2.ListPrefix()) } // --- Integration tests using mock S3 server --- @@ -512,13 +512,13 @@ func TestClient_ListObjects(t *testing.T) { require.Nil(t, err) // List with prefix client: should only see 3 - result, err := client.listObjects(ctx, "", 0) + result, err := client.listObjectsV2(ctx, "", 0) require.Nil(t, err) require.Len(t, result.Objects, 3) require.False(t, result.IsTruncated) // List with no-prefix client: should see all 4 - result, err = clientNoPrefix.listObjects(ctx, "", 0) + result, err = clientNoPrefix.listObjectsV2(ctx, "", 0) require.Nil(t, err) require.Len(t, result.Objects, 4) } @@ -537,20 +537,20 @@ func TestClient_ListObjects_Pagination(t *testing.T) { } // List with max-keys=2 - result, err := client.listObjects(ctx, "", 2) + result, err := client.listObjectsV2(ctx, "", 2) require.Nil(t, err) require.Len(t, result.Objects, 2) require.True(t, result.IsTruncated) require.NotEmpty(t, result.NextContinuationToken) // Get next page - result2, err := client.listObjects(ctx, result.NextContinuationToken, 2) + result2, err := 
client.listObjectsV2(ctx, result.NextContinuationToken, 2) require.Nil(t, err) require.Len(t, result2.Objects, 2) require.True(t, result2.IsTruncated) // Get last page - result3, err := client.listObjects(ctx, result2.NextContinuationToken, 2) + result3, err := client.listObjectsV2(ctx, result2.NextContinuationToken, 2) require.Nil(t, err) require.Len(t, result3.Objects, 1) require.False(t, result3.IsTruncated) @@ -568,7 +568,7 @@ func TestClient_ListAllObjects(t *testing.T) { require.Nil(t, err) } - objects, err := client.ListAllObjects(ctx) + objects, err := client.ListObjectsV2(ctx) require.Nil(t, err) require.Len(t, objects, 10) } @@ -688,7 +688,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) { } // List all 20k objects with pagination - objects, err := client.ListAllObjects(ctx) + objects, err := client.ListObjectsV2(ctx) require.Nil(t, err) require.Len(t, objects, numObjects) @@ -708,7 +708,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) { require.Nil(t, err) // List again: should have 19000 - objects, err = client.ListAllObjects(ctx) + objects, err = client.ListObjectsV2(ctx) require.Nil(t, err) require.Len(t, objects, numObjects-1000) } @@ -757,7 +757,7 @@ func TestClient_RealBucket(t *testing.T) { ctx := context.Background() // Clean up any leftover objects from previous runs - existing, err := client.ListAllObjects(ctx) + existing, err := client.ListObjectsV2(ctx) require.Nil(t, err) if len(existing) > 0 { keys := make([]string, len(existing)) @@ -823,7 +823,7 @@ func TestClient_RealBucket(t *testing.T) { } // List - objects, err := listClient.ListAllObjects(ctx) + objects, err := listClient.ListObjectsV2(ctx) require.Nil(t, err) require.Len(t, objects, 10) diff --git a/s3/types.go b/s3/types.go index 23ccb15b..a3694bd4 100644 --- a/s3/types.go +++ b/s3/types.go @@ -4,6 +4,8 @@ import ( "encoding/xml" "fmt" "net/http" + "net/url" + "strings" "time" ) @@ -19,7 +21,7 @@ type Config struct { HTTPClient *http.Client // if nil, http.DefaultClient 
is used } -// bucketURL returns the base URL for bucket-level operations. +// BucketURL returns the base URL for bucket-level operations. func (c *Config) BucketURL() string { if c.PathStyle { return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket) @@ -27,7 +29,7 @@ func (c *Config) BucketURL() string { return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint) } -// hostHeader returns the value for the Host header. +// HostHeader returns the value for the Host header. func (c *Config) HostHeader() string { if c.PathStyle { return c.Endpoint @@ -35,6 +37,38 @@ func (c *Config) HostHeader() string { return c.Bucket + "." + c.Endpoint } +// ListPrefix returns the prefix to use in ListObjectsV2 requests, +// with a trailing slash so that only objects under the prefix directory are returned. +func (c *Config) ListPrefix() string { + if c.Prefix != "" { + return c.Prefix + "/" + } + return "" +} + +// StripPrefix removes the configured prefix from a key returned by ListObjectsV2, +// so keys match what was passed to PutObject/GetObject/DeleteObjects. +func (c *Config) StripPrefix(key string) string { + if c.Prefix != "" { + return strings.TrimPrefix(key, c.Prefix+"/") + } + return key +} + +// ObjectKey prepends the configured prefix to the given key. +func (c *Config) ObjectKey(key string) string { + if c.Prefix != "" { + return c.Prefix + "/" + key + } + return key +} + +// ObjectURL returns the full URL for an object, automatically prepending the configured prefix. +func (c *Config) ObjectURL(key string) string { + u, _ := url.JoinPath(c.BucketURL(), c.ObjectKey(key)) + return u +} + // Object represents an S3 object returned by list operations. type Object struct { Key string @@ -42,13 +76,6 @@ type Object struct { LastModified time.Time } -// listResult holds the response from a single ListObjectsV2 page. 
-type listResult struct { - Objects []Object - IsTruncated bool - NextContinuationToken string -} - // ErrorResponse is returned when S3 responds with a non-2xx status code. type ErrorResponse struct { StatusCode int @@ -66,9 +93,16 @@ func (e *ErrorResponse) Error() string { // listObjectsV2Response is the XML response from S3 ListObjectsV2 type listObjectsV2Response struct { - Contents []listObject `xml:"Contents"` - IsTruncated bool `xml:"IsTruncated"` - NextContinuationToken string `xml:"NextContinuationToken"` + Contents []*listObject `xml:"Contents"` + IsTruncated bool `xml:"IsTruncated"` + NextContinuationToken string `xml:"NextContinuationToken"` +} + +// listObjectsV2Result holds the response from a single ListObjectsV2 page. +type listObjectsV2Result struct { + Objects []*Object + IsTruncated bool + NextContinuationToken string } type listObject struct { @@ -77,8 +111,8 @@ type listObject struct { LastModified string `xml:"LastModified"` } -// deleteRequest is the XML request body for S3 DeleteObjects -type deleteRequest struct { +// deleteObjectsRequest is the XML request body for S3 DeleteObjects +type deleteObjectsRequest struct { XMLName xml.Name `xml:"Delete"` Quiet bool `xml:"Quiet"` Objects []*deleteObject `xml:"Object"` @@ -88,8 +122,8 @@ type deleteObject struct { Key string `xml:"Key"` } -// deleteResult is the XML response from S3 DeleteObjects -type deleteResult struct { +// deleteObjectsResult is the XML response from S3 DeleteObjects +type deleteObjectsResult struct { Errors []deleteError `xml:"Error"` } @@ -99,8 +133,8 @@ type deleteError struct { Message string `xml:"Message"` } -// MultipartUpload represents an in-progress multipart upload returned by listMultipartUploads. -type MultipartUpload struct { +// multipartUpload represents an in-progress multipart upload returned by listMultipartUploads. 
+type multipartUpload struct { Key string UploadID string Initiated time.Time @@ -108,10 +142,10 @@ type MultipartUpload struct { // listMultipartUploadsResult is the XML response from S3 listMultipartUploads type listMultipartUploadsResult struct { - Uploads []listUpload `xml:"Upload"` - IsTruncated bool `xml:"IsTruncated"` - NextKeyMarker string `xml:"NextKeyMarker"` - NextUploadIDMarker string `xml:"NextUploadIdMarker"` + Uploads []*listUpload `xml:"Upload"` + IsTruncated bool `xml:"IsTruncated"` + NextKeyMarker string `xml:"NextKeyMarker"` + NextUploadIDMarker string `xml:"NextUploadIdMarker"` } type listUpload struct { diff --git a/s3/util.go b/s3/util.go index 546a940a..06f7e3d1 100644 --- a/s3/util.go +++ b/s3/util.go @@ -20,8 +20,8 @@ const ( // Sent as the payload hash for streaming uploads where the body is not buffered in memory unsignedPayload = "UNSIGNED-PAYLOAD" - // maxResponseBytes caps the size of S3 response bodies we read into memory (10 MB) - maxResponseBytes = 10 * 1024 * 1024 + // maxResponseBytes caps the size of S3 response bodies we read into memory + maxResponseBytes = 2 * 1024 * 1024 // partSize is the size of each part for multipart uploads (5 MB). This is also the threshold // above which PutObject switches from a simple PUT to multipart upload. 
S3 requires a minimum @@ -29,7 +29,7 @@ const ( partSize = 5 * 1024 * 1024 // maxPages is the max number of pages to iterate through when listing objects - maxPages = 10000 + maxPages = 500 ) // ParseURL parses an S3 URL of the form: @@ -88,7 +88,7 @@ func ParseURL(s3URL string) (*Config, error) { func parseError(resp *http.Response) error { body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) if err != nil { - return fmt.Errorf("s3: reading error response: %w", err) + return fmt.Errorf("error reading S3 error response: %w", err) } return parseErrorFromBytes(resp.StatusCode, body) } diff --git a/tools/s3cli/main.go b/tools/s3cli/main.go index 1dbac0cf..0e640823 100644 --- a/tools/s3cli/main.go +++ b/tools/s3cli/main.go @@ -105,7 +105,7 @@ func cmdRm(ctx context.Context, client *s3.Client) { } func cmdLs(ctx context.Context, client *s3.Client) { - objects, err := client.ListAllObjects(ctx) + objects, err := client.ListObjectsV2(ctx) if err != nil { fail("ls: %s", err) } From 393f730d11f78d7b2920f707d215b3e5a68be3fe Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 16:12:46 -0400 Subject: [PATCH 12/32] More manual review and refinement --- s3/client.go | 34 +++++++++------------- s3/client_multipart.go | 64 +++++++++++++++++++++--------------------- s3/client_test.go | 10 +++---- s3/types.go | 31 ++++++++------------ 4 files changed, 62 insertions(+), 77 deletions(-) diff --git a/s3/client.go b/s3/client.go index 754a1bfb..83d8195e 100644 --- a/s3/client.go +++ b/s3/client.go @@ -118,9 +118,16 @@ func (c *Client) ListObjectsV2(ctx context.Context) ([]*Object, error) { if err != nil { return nil, err } - for _, obj := range result.Objects { - obj.Key = c.config.StripPrefix(obj.Key) - all = append(all, obj) + for _, obj := range result.Contents { + var lastModified time.Time + if obj.LastModified != "" { + lastModified, _ = time.Parse(time.RFC3339, obj.LastModified) + } + all = append(all, &Object{ + Key: c.config.StripPrefix(obj.Key), 
+ Size: obj.Size, + LastModified: lastModified, + }) } if !result.IsTruncated { return all, nil @@ -148,27 +155,11 @@ func (c *Client) listObjectsV2(ctx context.Context, continuationToken string, ma if err != nil { return nil, err } - var result listObjectsV2Response + var result listObjectsV2Result if err := xml.Unmarshal(respBody, &result); err != nil { return nil, fmt.Errorf("failed to unmarshal list object response: %w", err) } - objects := make([]*Object, len(result.Contents)) - for i, obj := range result.Contents { - var lastModified time.Time - if obj.LastModified != "" { - lastModified, _ = time.Parse(time.RFC3339, obj.LastModified) - } - objects[i] = &Object{ - Key: obj.Key, - Size: obj.Size, - LastModified: lastModified, - } - } - return &listObjectsV2Result{ - Objects: objects, - IsTruncated: result.IsTruncated, - NextContinuationToken: result.NextContinuationToken, - }, nil + return &result, nil } // DeleteObjects removes multiple objects in a single batch request. Keys are automatically @@ -222,6 +213,7 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { // If body is nil, the request is sent with an empty payload. If body is non-nil, it is sent // with a computed SHA-256 payload hash and Content-Type: application/xml. 
func (c *Client) do(ctx context.Context, op, method, reqURL string, body []byte, headers map[string]string) ([]byte, error) { + log.Tag(tagS3Client).Trace("Performing request %s %s %s (body: %d bytes)", op, method, reqURL, len(body)) var reader io.Reader var hash string if body != nil { diff --git a/s3/client_multipart.go b/s3/client_multipart.go index f6b68784..5e98db38 100644 --- a/s3/client_multipart.go +++ b/s3/client_multipart.go @@ -26,7 +26,6 @@ func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) e } for _, u := range uploads { if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) { - log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated) c.abortMultipartUpload(ctx, u.Key, u.UploadID) } } @@ -47,13 +46,13 @@ func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload, query.Set("key-marker", keyMarker) query.Set("upload-id-marker", uploadIDMarker) } - respBody, err := c.do(ctx, "listMultipartUploads", http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil) + respBody, err := c.do(ctx, "ListMultipartUploads", http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil) if err != nil { return nil, err } var result listMultipartUploadsResult if err := xml.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("s3: listMultipartUploads XML: %w", err) + return nil, fmt.Errorf("error unmarshalling multipart upload result: %w", err) } for _, u := range result.Uploads { var initiated time.Time @@ -75,6 +74,22 @@ func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload, return nil, fmt.Errorf("s3: listMultipartUploads exceeded %d pages", maxPages) } +// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. 
+func (c *Client) abortMultipartUpload(ctx context.Context, key, uploadID string) { + log.Tag(tagS3Client).Info("Aborting multipart upload for object %s", key) + reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) + if err != nil { + return + } + c.signV4(req, emptyPayloadHash) + resp, err := c.http.Do(req) + if err != nil { + return + } + resp.Body.Close() +} + // putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize // chunks, uploading each as a separate part. This allows uploading without knowing the total // body size in advance. @@ -88,9 +103,9 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea } // Step 2: Upload parts - var parts []completedPart - buf := make([]byte, partSize) partNumber := 1 + buf := make([]byte, partSize) + var parts []*completedPart for { n, err := io.ReadFull(body, buf) if n > 0 { @@ -99,7 +114,10 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea c.abortMultipartUpload(ctx, key, uploadID) return uploadErr } - parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag}) + parts = append(parts, &completedPart{ + PartNumber: partNumber, + ETag: etag, + }) partNumber++ } if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { @@ -123,38 +141,36 @@ func (c *Client) initiateMultipartUpload(ctx context.Context, key string) (strin } var result initiateMultipartUploadResult if err := xml.Unmarshal(respBody, &result); err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) + return "", fmt.Errorf("error unmarshalling initiate multipart upload response: %w", err) } - log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", key, result.UploadID) return result.UploadID, nil } // uploadPart uploads a single part of a multipart upload and returns the ETag. 
func (c *Client) uploadPart(ctx context.Context, key, uploadID string, partNumber int, data []byte) (string, error) { - log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", key, partNumber, len(data)) + log.Tag(tagS3Client).Debug("Uploading multipart part for object %s, part %d, size %d", key, partNumber, len(data)) reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.config.ObjectURL(key), partNumber, url.QueryEscape(uploadID)) req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) if err != nil { - return "", fmt.Errorf("s3: UploadPart request: %w", err) + return "", fmt.Errorf("error creating multipart upload part request for object %s: %w", key, err) } req.ContentLength = int64(len(data)) c.signV4(req, unsignedPayload) resp, err := c.http.Do(req) if err != nil { - return "", fmt.Errorf("s3: UploadPart: %w", err) + return "", fmt.Errorf("error uploading multipart part for object %s: %w", key, err) } defer resp.Body.Close() if !isHTTPSuccess(resp) { return "", parseError(resp) } - etag := resp.Header.Get("ETag") - return etag, nil + return resp.Header.Get("ETag"), nil } // completeMultipartUpload finalizes a multipart upload with the given parts. 
-func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID string, parts []completedPart) error { - log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", key, uploadID, len(parts)) - bodyBytes, err := xml.Marshal(completeMultipartUploadRequest{Parts: parts}) +func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID string, parts []*completedPart) error { + log.Tag(tagS3Client).Debug("Completing multipart upload for object %s, %d parts", key, len(parts)) + bodyBytes, err := xml.Marshal(&completeMultipartUploadRequest{Parts: parts}) if err != nil { return fmt.Errorf("s3: CompleteMultipartUpload marshal: %w", err) } @@ -170,19 +186,3 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri } return nil } - -// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. -func (c *Client) abortMultipartUpload(ctx context.Context, key, uploadID string) { - log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", key, uploadID) - reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID)) - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) - if err != nil { - return - } - c.signV4(req, emptyPayloadHash) - resp, err := c.http.Do(req) - if err != nil { - return - } - resp.Body.Close() -} diff --git a/s3/client_test.go b/s3/client_test.go index d488c832..d267a6a8 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -514,13 +514,13 @@ func TestClient_ListObjects(t *testing.T) { // List with prefix client: should only see 3 result, err := client.listObjectsV2(ctx, "", 0) require.Nil(t, err) - require.Len(t, result.Objects, 3) + require.Len(t, result.Contents, 3) require.False(t, result.IsTruncated) // List with no-prefix client: should see all 4 result, err = clientNoPrefix.listObjectsV2(ctx, "", 0) require.Nil(t, err) - require.Len(t, result.Objects, 4) + require.Len(t, 
result.Contents, 4) } func TestClient_ListObjects_Pagination(t *testing.T) { @@ -539,20 +539,20 @@ func TestClient_ListObjects_Pagination(t *testing.T) { // List with max-keys=2 result, err := client.listObjectsV2(ctx, "", 2) require.Nil(t, err) - require.Len(t, result.Objects, 2) + require.Len(t, result.Contents, 2) require.True(t, result.IsTruncated) require.NotEmpty(t, result.NextContinuationToken) // Get next page result2, err := client.listObjectsV2(ctx, result.NextContinuationToken, 2) require.Nil(t, err) - require.Len(t, result2.Objects, 2) + require.Len(t, result2.Contents, 2) require.True(t, result2.IsTruncated) // Get last page result3, err := client.listObjectsV2(ctx, result2.NextContinuationToken, 2) require.Nil(t, err) - require.Len(t, result3.Objects, 1) + require.Len(t, result3.Contents, 1) require.False(t, result3.IsTruncated) } diff --git a/s3/types.go b/s3/types.go index a3694bd4..1782b88d 100644 --- a/s3/types.go +++ b/s3/types.go @@ -91,20 +91,13 @@ func (e *ErrorResponse) Error() string { return fmt.Sprintf("s3: HTTP %d: %s", e.StatusCode, e.Body) } -// listObjectsV2Response is the XML response from S3 ListObjectsV2 -type listObjectsV2Response struct { +// listObjectsV2Result is the XML response from S3 ListObjectsV2 +type listObjectsV2Result struct { Contents []*listObject `xml:"Contents"` IsTruncated bool `xml:"IsTruncated"` NextContinuationToken string `xml:"NextContinuationToken"` } -// listObjectsV2Result holds the response from a single ListObjectsV2 page. 
-type listObjectsV2Result struct { - Objects []*Object - IsTruncated bool - NextContinuationToken string -} - type listObject struct { Key string `xml:"Key"` Size int64 `xml:"Size"` @@ -124,7 +117,7 @@ type deleteObject struct { // deleteObjectsResult is the XML response from S3 DeleteObjects type deleteObjectsResult struct { - Errors []deleteError `xml:"Error"` + Errors []*deleteError `xml:"Error"` } type deleteError struct { @@ -133,13 +126,6 @@ type deleteError struct { Message string `xml:"Message"` } -// multipartUpload represents an in-progress multipart upload returned by listMultipartUploads. -type multipartUpload struct { - Key string - UploadID string - Initiated time.Time -} - // listMultipartUploadsResult is the XML response from S3 listMultipartUploads type listMultipartUploadsResult struct { Uploads []*listUpload `xml:"Upload"` @@ -154,6 +140,13 @@ type listUpload struct { Initiated string `xml:"Initiated"` } +// multipartUpload represents an in-progress multipart upload returned by listMultipartUploads. 
+type multipartUpload struct { + Key string + UploadID string + Initiated time.Time +} + // initiateMultipartUploadResult is the XML response from S3 InitiateMultipartUpload type initiateMultipartUploadResult struct { UploadID string `xml:"UploadId"` @@ -161,8 +154,8 @@ type initiateMultipartUploadResult struct { // completeMultipartUploadRequest is the XML request body for S3 CompleteMultipartUpload type completeMultipartUploadRequest struct { - XMLName xml.Name `xml:"CompleteMultipartUpload"` - Parts []completedPart `xml:"Part"` + XMLName xml.Name `xml:"CompleteMultipartUpload"` + Parts []*completedPart `xml:"Part"` } // completedPart represents a successfully uploaded part for CompleteMultipartUpload From 1742302f83e1a0b3963924bc354d75ace228793e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 16:27:41 -0400 Subject: [PATCH 13/32] More tests and human review --- attachment/store.go | 2 + s3/client_multipart.go | 11 +++-- s3/client_test.go | 4 +- s3/types.go | 6 +-- s3/util.go | 4 +- s3/util_test.go | 4 +- util/limit_test.go | 100 ++++++++++++++++++++++++++++++++++++++++- 7 files changed, 115 insertions(+), 16 deletions(-) diff --git a/attachment/store.go b/attachment/store.go index 0192b09a..10dcd17b 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -110,9 +110,11 @@ func (c *Store) Remove(ids ...string) error { return errInvalidFileID } } + // Remove from backend if err := c.backend.Delete(ids...); err != nil { return err } + // Update total cache size c.mu.Lock() for _, id := range ids { if size, ok := c.sizes[id]; ok { diff --git a/s3/client_multipart.go b/s3/client_multipart.go index 5e98db38..198175d4 100644 --- a/s3/client_multipart.go +++ b/s3/client_multipart.go @@ -71,7 +71,7 @@ func (c *Client) listMultipartUploads(ctx context.Context) ([]*multipartUpload, keyMarker = result.NextKeyMarker uploadIDMarker = result.NextUploadIDMarker } - return nil, fmt.Errorf("s3: listMultipartUploads exceeded %d pages", maxPages) + return 
nil, fmt.Errorf("error listing multipart uploads, exceeded %d pages", maxPages) } // abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. @@ -122,10 +122,9 @@ func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Rea } if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { break - } - if err != nil { + } else if err != nil { c.abortMultipartUpload(ctx, key, uploadID) - return fmt.Errorf("s3: PutObject read: %w", err) + return fmt.Errorf("error uploading object %s, reading from client failed: %w", key, err) } } @@ -172,7 +171,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri log.Tag(tagS3Client).Debug("Completing multipart upload for object %s, %d parts", key, len(parts)) bodyBytes, err := xml.Marshal(&completeMultipartUploadRequest{Parts: parts}) if err != nil { - return fmt.Errorf("s3: CompleteMultipartUpload marshal: %w", err) + return fmt.Errorf("error marshalling complete multipart upload request: %w", err) } reqURL := fmt.Sprintf("%s?uploadId=%s", c.config.ObjectURL(key), url.QueryEscape(uploadID)) respBody, err := c.do(ctx, "CompleteMultipartUpload", http.MethodPost, reqURL, bodyBytes, nil) @@ -180,7 +179,7 @@ func (c *Client) completeMultipartUpload(ctx context.Context, key, uploadID stri return err } // Check if the response contains an error (S3 can return 200 with an error body) - var errResp ErrorResponse + var errResp errorResponse if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" { return &errResp } diff --git a/s3/client_test.go b/s3/client_test.go index d267a6a8..84402831 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -456,7 +456,7 @@ func TestClient_GetObject_NotFound(t *testing.T) { _, _, err := client.GetObject(context.Background(), "nonexistent") require.Error(t, err) - var errResp *ErrorResponse + var errResp *errorResponse require.ErrorAs(t, err, &errResp) require.Equal(t, 404, errResp.StatusCode) require.Equal(t, 
"NoSuchKey", errResp.Code) @@ -799,7 +799,7 @@ func TestClient_RealBucket(t *testing.T) { // Get after delete should fail _, _, err = client.GetObject(ctx, key) require.Error(t, err) - var errResp *ErrorResponse + var errResp *errorResponse require.ErrorAs(t, err, &errResp) require.Equal(t, 404, errResp.StatusCode) }) diff --git a/s3/types.go b/s3/types.go index 1782b88d..96b62649 100644 --- a/s3/types.go +++ b/s3/types.go @@ -76,15 +76,15 @@ type Object struct { LastModified time.Time } -// ErrorResponse is returned when S3 responds with a non-2xx status code. -type ErrorResponse struct { +// errorResponse is returned when S3 responds with a non-2xx status code. +type errorResponse struct { StatusCode int Code string `xml:"Code"` Message string `xml:"Message"` Body string `xml:"-"` // raw response body } -func (e *ErrorResponse) Error() string { +func (e *errorResponse) Error() string { if e.Code != "" { return fmt.Sprintf("s3: %s (HTTP %d): %s", e.Code, e.StatusCode, e.Message) } diff --git a/s3/util.go b/s3/util.go index 06f7e3d1..0bcc96d2 100644 --- a/s3/util.go +++ b/s3/util.go @@ -84,7 +84,7 @@ func ParseURL(s3URL string) (*Config, error) { }, nil } -// parseError reads an S3 error response and returns an *ErrorResponse. +// parseError reads an S3 error response and returns an *errorResponse. 
func parseError(resp *http.Response) error { body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) if err != nil { @@ -94,7 +94,7 @@ func parseError(resp *http.Response) error { } func parseErrorFromBytes(statusCode int, body []byte) error { - errResp := &ErrorResponse{ + errResp := &errorResponse{ StatusCode: statusCode, Body: string(body), } diff --git a/s3/util_test.go b/s3/util_test.go index 3f08911d..93ddd707 100644 --- a/s3/util_test.go +++ b/s3/util_test.go @@ -163,7 +163,7 @@ func TestParseError_XMLResponse(t *testing.T) { xmlBody := []byte(`NoSuchKeyThe specified key does not exist.`) err := parseErrorFromBytes(404, xmlBody) - var errResp *ErrorResponse + var errResp *errorResponse require.ErrorAs(t, err, &errResp) require.Equal(t, 404, errResp.StatusCode) require.Equal(t, "NoSuchKey", errResp.Code) @@ -173,7 +173,7 @@ func TestParseError_XMLResponse(t *testing.T) { func TestParseError_NonXMLResponse(t *testing.T) { err := parseErrorFromBytes(500, []byte("internal server error")) - var errResp *ErrorResponse + var errResp *errorResponse require.ErrorAs(t, err, &errResp) require.Equal(t, 500, errResp.StatusCode) require.Equal(t, "", errResp.Code) // XML parsing failed, no code diff --git a/util/limit_test.go b/util/limit_test.go index 51595351..9ca9fe39 100644 --- a/util/limit_test.go +++ b/util/limit_test.go @@ -2,9 +2,12 @@ package util import ( "bytes" - "github.com/stretchr/testify/require" + "io" + "strings" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestFixedLimiter_AllowValueReset(t *testing.T) { @@ -147,3 +150,98 @@ func TestLimitWriter_WriteTwoDifferentLimiters_Wait_FixedLimiterFail(t *testing. 
_, err = lw.Write(make([]byte, 8)) // <<< FixedLimiter fails require.Equal(t, ErrLimitReached, err) } + +func TestCountingReader_Total(t *testing.T) { + cr := NewCountingReader(strings.NewReader("hello world")) + buf := make([]byte, 5) + + n, err := cr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + require.Equal(t, int64(5), cr.Total()) + + n, err = cr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + require.Equal(t, int64(10), cr.Total()) + + n, err = cr.Read(buf) + require.Nil(t, err) + require.Equal(t, 1, n) + require.Equal(t, int64(11), cr.Total()) + + _, err = cr.Read(buf) + require.Equal(t, io.EOF, err) + require.Equal(t, int64(11), cr.Total()) +} + +func TestCountingReader_Empty(t *testing.T) { + cr := NewCountingReader(strings.NewReader("")) + require.Equal(t, int64(0), cr.Total()) + + _, err := cr.Read(make([]byte, 10)) + require.Equal(t, io.EOF, err) + require.Equal(t, int64(0), cr.Total()) +} + +func TestLimitReader_ReadNoLimiter(t *testing.T) { + lr := NewLimitReader(strings.NewReader("hello")) + data, err := io.ReadAll(lr) + require.Nil(t, err) + require.Equal(t, "hello", string(data)) +} + +func TestLimitReader_ReadOneLimiter(t *testing.T) { + l := NewFixedLimiter(10) + lr := NewLimitReader(strings.NewReader("hello world!"), l) + + buf := make([]byte, 5) + n, err := lr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + require.Equal(t, int64(5), l.Value()) + + n, err = lr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + require.Equal(t, int64(10), l.Value()) + + _, err = lr.Read(buf) + require.Equal(t, ErrLimitReached, err) +} + +func TestLimitReader_ReadTwoLimiters(t *testing.T) { + l1 := NewFixedLimiter(11) + l2 := NewFixedLimiter(8) + lr := NewLimitReader(strings.NewReader("hello world!"), l1, l2) + + buf := make([]byte, 5) + n, err := lr.Read(buf) + require.Nil(t, err) + require.Equal(t, 5, n) + + // Second read: l2 (limit 8) should reject 5 more bytes + _, err = lr.Read(buf) + require.Equal(t, 
ErrLimitReached, err) + // l1 should have been reverted + require.Equal(t, int64(5), l1.Value()) + require.Equal(t, int64(5), l2.Value()) +} + +func TestLimitReader_ReadAll(t *testing.T) { + l := NewFixedLimiter(100) + lr := NewLimitReader(strings.NewReader("hello"), l) + data, err := io.ReadAll(lr) + require.Nil(t, err) + require.Equal(t, "hello", string(data)) + require.Equal(t, int64(5), l.Value()) +} + +func TestLimitReader_ReadExactLimit(t *testing.T) { + l := NewFixedLimiter(5) + lr := NewLimitReader(bytes.NewReader(make([]byte, 5)), l) + data, err := io.ReadAll(lr) + require.Nil(t, err) + require.Equal(t, 5, len(data)) + require.Equal(t, int64(5), l.Value()) +} From 6a820b503046dd3e2838b874631ca7b7dbd8bec2 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 16:29:58 -0400 Subject: [PATCH 14/32] Tags --- attachment/backend_file.go | 2 +- attachment/backend_s3.go | 2 +- attachment/store.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/attachment/backend_file.go b/attachment/backend_file.go index 8aaf20b9..260236d1 100644 --- a/attachment/backend_file.go +++ b/attachment/backend_file.go @@ -9,7 +9,7 @@ import ( "heckel.io/ntfy/v2/log" ) -const tagFileBackend = "file_backend" +const tagFileBackend = "attachment_file" type fileBackend struct { dir string diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index eb911edc..61d1c7b1 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -10,7 +10,7 @@ import ( ) const ( - tagS3Backend = "s3_backend" + tagS3Backend = "attachment_s3" deleteBatchSize = 1000 ) diff --git a/attachment/store.go b/attachment/store.go index 10dcd17b..8059a0cf 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -15,7 +15,7 @@ import ( ) const ( - tagStore = "attachment_cache" + tagStore = "attachment_store" syncInterval = 15 * time.Minute // How often to run the background sync loop orphanGracePeriod = time.Hour // Don't delete orphaned objects younger than this 
to avoid races with in-flight uploads ) From 78d3138565ba17f05f74a04af5f536700b85586d Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 16:54:16 -0400 Subject: [PATCH 15/32] Fix flaky test --- server/server_webpush_test.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/server/server_webpush_test.go b/server/server_webpush_test.go index 047c8708..bba13db4 100644 --- a/server/server_webpush_test.go +++ b/server/server_webpush_test.go @@ -235,13 +235,12 @@ func TestServer_WebPush_Publish_RemoveOnError(t *testing.T) { request(t, s, "POST", "/test-topic", "web push test", nil) - waitFor(t, func() bool { - return received.Load() - }) - // Receiving the 410 should've caused the publisher to expire all subscriptions on the endpoint - - requireSubscriptionCount(t, s, "test-topic", 0) + waitFor(t, func() bool { + subs, err := s.webPush.SubscriptionsForTopic("test-topic") + require.Nil(t, err) + return len(subs) == 0 + }) requireSubscriptionCount(t, s, "test-topic-abc", 0) }) } From b3a8f18019e3744c4ddd09af392318dc1756b3ed Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 17:03:29 -0400 Subject: [PATCH 16/32] Docs --- attachment/store.go | 2 +- docs/config.md | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/attachment/store.go b/attachment/store.go index 8059a0cf..ba2e22cc 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -162,7 +162,7 @@ func (c *Store) sync() error { sizes[obj.ID] = obj.Size } } - log.Tag(tagStore).Debug("Attachment cache size updated to %s", util.FormatSizeHuman(size)) + log.Tag(tagStore).Debug("Attachment store updated: %d attachment(s), %s", len(localIDs), util.FormatSizeHuman(size)) c.mu.Lock() c.size = size c.sizes = sizes diff --git a/docs/config.md b/docs/config.md index edfa43ff..853d4891 100644 --- a/docs/config.md +++ b/docs/config.md @@ -547,6 +547,11 @@ When `endpoint` is specified, path-style addressing is enabled automatically (us 
attachment-cache-dir: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com" ``` +**Cleanup behavior:** A background sync runs every 15 minutes to reconcile the S3 bucket (or configured prefix) with +the server's message database. Objects whose keys match attachment file IDs that are no longer referenced in the database +(and are older than 1 hour) are automatically deleted. This also cleans up incomplete S3 multipart uploads that were +abandoned due to interrupted or failed attachment uploads. + Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit` and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse. From b81218953a9ac89b57ad7e86e6d97de94c67b46e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 21:14:49 -0400 Subject: [PATCH 17/32] Allow streaming to S3 --- attachment/backend.go | 2 +- attachment/backend_file.go | 12 ++++- attachment/backend_s3.go | 4 +- attachment/store.go | 14 +++--- attachment/store_file_test.go | 46 ++++++++++++++---- attachment/store_s3_test.go | 22 ++++----- s3/client.go | 29 ++++++++---- s3/client_test.go | 87 +++++++++++++++++++++++++++++------ s3/util.go | 4 ++ server/server.go | 19 ++++---- server/server_test.go | 4 +- tools/s3cli/main.go | 8 +++- 12 files changed, 181 insertions(+), 70 deletions(-) diff --git a/attachment/backend.go b/attachment/backend.go index e95fc91e..921ceb3e 100644 --- a/attachment/backend.go +++ b/attachment/backend.go @@ -15,7 +15,7 @@ type object struct { // backend is a minimal I/O interface for storing and retrieving attachment files. // It has no knowledge of size tracking, limiting, or ID validation. 
type backend interface { - Put(id string, in io.Reader) error + Put(id string, reader io.Reader, untrustedLength int64) error Get(id string) (io.ReadCloser, int64, error) List() ([]object, error) Delete(ids ...string) error diff --git a/attachment/backend_file.go b/attachment/backend_file.go index 260236d1..8726ddf4 100644 --- a/attachment/backend_file.go +++ b/attachment/backend_file.go @@ -1,6 +1,7 @@ package attachment import ( + "fmt" "io" "os" "path/filepath" @@ -24,16 +25,23 @@ func newFileBackend(dir string) (*fileBackend, error) { return &fileBackend{dir: dir}, nil } -func (b *fileBackend) Put(id string, in io.Reader) error { +func (b *fileBackend) Put(id string, reader io.Reader, untrustedLength int64) error { + if untrustedLength > 0 { + reader = io.LimitReader(reader, untrustedLength) + } file := filepath.Join(b.dir, id) f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) if err != nil { return err } defer f.Close() - if _, err := io.Copy(f, in); err != nil { + n, err := io.Copy(f, reader) + if err != nil { os.Remove(file) return err + } else if untrustedLength > 0 && n != untrustedLength { + os.Remove(file) + return fmt.Errorf("content length mismatch: claimed %d, got %d", untrustedLength, n) } if err := f.Close(); err != nil { os.Remove(file) diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index 61d1c7b1..081d6002 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -24,8 +24,8 @@ func newS3Backend(client *s3.Client) *s3Backend { return &s3Backend{client: client} } -func (b *s3Backend) Put(id string, in io.Reader) error { - return b.client.PutObject(context.Background(), id, in) +func (b *s3Backend) Put(id string, reader io.Reader, untrustedLength int64) error { + return b.client.PutObject(context.Background(), id, reader, untrustedLength) } func (b *s3Backend) Get(id string) (io.ReadCloser, int64, error) { diff --git a/attachment/store.go b/attachment/store.go index ba2e22cc..6e7cfb99 100644 --- 
a/attachment/store.go +++ b/attachment/store.go @@ -72,20 +72,22 @@ func newStore(backend backend, totalSizeLimit int64, localIDs func() ([]string, } // Write stores an attachment file. The id is validated, and the write is subject to -// the total size limit and any additional limiters. -func (c *Store) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) { +// the total size limit and any additional limiters. The untrustedLength is a hint +// from the client's Content-Length header; backends may use it to optimize uploads (e.g. +// streaming directly to S3 without buffering). +func (c *Store) Write(id string, reader io.Reader, untrustedLength int64, limiters ...util.Limiter) (int64, error) { if !fileIDRegex.MatchString(id) { return 0, errInvalidFileID } log.Tag(tagStore).Field("message_id", id).Debug("Writing attachment") limiters = append(limiters, util.NewFixedLimiter(c.Remaining())) - cr := util.NewCountingReader(in) - lr := util.NewLimitReader(cr, limiters...) - if err := c.backend.Put(id, lr); err != nil { + countingReader := util.NewCountingReader(reader) + limitReader := util.NewLimitReader(countingReader, limiters...) 
+ if err := c.backend.Put(id, limitReader, untrustedLength); err != nil { c.backend.Delete(id) //nolint:errcheck return 0, err } - size := cr.Total() + size := countingReader.Total() c.mu.Lock() c.size += size c.sizes[id] = size diff --git a/attachment/store_file_test.go b/attachment/store_file_test.go index c65bad92..998a2bea 100644 --- a/attachment/store_file_test.go +++ b/attachment/store_file_test.go @@ -19,7 +19,7 @@ var ( func TestFileStore_Write_Success(t *testing.T) { dir, c := newTestFileStore(t) - size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), util.NewFixedLimiter(999)) + size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), 0, util.NewFixedLimiter(999)) require.Nil(t, err) require.Equal(t, int64(11), size) require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl")) @@ -29,7 +29,7 @@ func TestFileStore_Write_Success(t *testing.T) { func TestFileStore_Write_Read_Success(t *testing.T) { _, c := newTestFileStore(t) - size, err := c.Write("abcdefghijkl", strings.NewReader("hello world")) + size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 0) require.Nil(t, err) require.Equal(t, int64(11), size) @@ -45,7 +45,7 @@ func TestFileStore_Write_Read_Success(t *testing.T) { func TestFileStore_Write_Remove_Success(t *testing.T) { dir, c := newTestFileStore(t) // max = 10k (10240), each = 1k (1024) for i := 0; i < 10; i++ { // 10x999 = 9990 - size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999))) + size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)), 0) require.Nil(t, err) require.Equal(t, int64(999), size) } @@ -64,22 +64,48 @@ func TestFileStore_Write_Remove_Success(t *testing.T) { func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) { dir, c := newTestFileStore(t) for i := 0; i < 10; i++ { - size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray)) + size, err := 
c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray), 0) require.Nil(t, err) require.Equal(t, int64(1024), size) } - _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray)) + _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray), 0) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abcdefghijkX") } func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) { dir, c := newTestFileStore(t) - _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), util.NewFixedLimiter(1000)) + _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), 0, util.NewFixedLimiter(1000)) require.Equal(t, util.ErrLimitReached, err) require.NoFileExists(t, dir+"/abcdefghijkl") } +func TestFileStore_Write_UntrustedContentLengthExact(t *testing.T) { + dir, c := newTestFileStore(t) + size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 11) + require.Nil(t, err) + require.Equal(t, int64(11), size) + require.Equal(t, "hello world", readFile(t, dir+"/abcdefghijkl")) +} + +func TestFileStore_Write_UntrustedContentLengthBodyLonger(t *testing.T) { + dir, c := newTestFileStore(t) + // Body has 11 bytes, but we claim 5 — only first 5 bytes should be stored + size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 5) + require.Nil(t, err) + require.Equal(t, int64(5), size) + require.Equal(t, "hello", readFile(t, dir+"/abcdefghijkl")) +} + +func TestFileStore_Write_UntrustedContentLengthBodyShorter(t *testing.T) { + dir, c := newTestFileStore(t) + // Body has 5 bytes, but we claim 100 — should fail with content length mismatch + _, err := c.Write("abcdefghijkl", strings.NewReader("hello"), 100) + require.Error(t, err) + require.Contains(t, err.Error(), "content length mismatch") + require.NoFileExists(t, dir+"/abcdefghijkl") +} + func TestFileStore_Read_NotFound(t *testing.T) { _, c := newTestFileStore(t) _, _, err := c.Read("abcdefghijkl") @@ -90,11 +116,11 @@ func 
TestFileStore_Sync(t *testing.T) { dir, c := newTestFileStore(t) // Write some files - _, err := c.Write("abcdefghijk0", strings.NewReader("file0")) + _, err := c.Write("abcdefghijk0", strings.NewReader("file0"), 0) require.Nil(t, err) - _, err = c.Write("abcdefghijk1", strings.NewReader("file1")) + _, err = c.Write("abcdefghijk1", strings.NewReader("file1"), 0) require.Nil(t, err) - _, err = c.Write("abcdefghijk2", strings.NewReader("file2")) + _, err = c.Write("abcdefghijk2", strings.NewReader("file2"), 0) require.Nil(t, err) require.Equal(t, int64(15), c.Size()) @@ -124,7 +150,7 @@ func TestFileStore_Sync_SkipsRecentFiles(t *testing.T) { dir, c := newTestFileStore(t) // Write a file - _, err := c.Write("abcdefghijk0", strings.NewReader("file0")) + _, err := c.Write("abcdefghijk0", strings.NewReader("file0"), 0) require.Nil(t, err) // Set the ID provider to return empty (no valid IDs) diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 3ad5a93c..37bd0ecb 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -26,7 +26,7 @@ func TestS3Store_WriteReadRemove(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) // Write - size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world")) + size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world"), 0) require.Nil(t, err) require.Equal(t, int64(11), size) require.Equal(t, int64(11), cache.Size()) @@ -55,7 +55,7 @@ func TestS3Store_WriteNoPrefix(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "", 10*1024) - size, err := cache.Write("abcdefghijkl", strings.NewReader("test")) + size, err := cache.Write("abcdefghijkl", strings.NewReader("test"), 0) require.Nil(t, err) require.Equal(t, int64(4), size) @@ -74,13 +74,13 @@ func TestS3Store_WriteTotalSizeLimit(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 100) // First write fits - _, err := cache.Write("abcdefghijk0", 
bytes.NewReader(make([]byte, 80))) + _, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)), 0) require.Nil(t, err) require.Equal(t, int64(80), cache.Size()) require.Equal(t, int64(20), cache.Remaining()) // Second write exceeds total limit - _, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50))) + _, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)), 0) require.ErrorIs(t, err, util.ErrLimitReached) } @@ -90,7 +90,7 @@ func TestS3Store_WriteFileSizeLimit(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), util.NewFixedLimiter(100)) + _, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), 0, util.NewFixedLimiter(100)) require.ErrorIs(t, err, util.ErrLimitReached) } @@ -101,7 +101,7 @@ func TestS3Store_WriteRemoveMultiple(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) for i := 0; i < 5; i++ { - _, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100))) + _, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)), 0) require.Nil(t, err) } require.Equal(t, int64(500), cache.Size()) @@ -126,7 +126,7 @@ func TestS3Store_InvalidID(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - _, err := cache.Write("bad", strings.NewReader("x")) + _, err := cache.Write("bad", strings.NewReader("x"), 0) require.Equal(t, errInvalidFileID, err) _, _, err = cache.Read("bad") @@ -143,11 +143,11 @@ func TestS3Store_Sync(t *testing.T) { cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) // Write some files - _, err := cache.Write("abcdefghijk0", strings.NewReader("file0")) + _, err := cache.Write("abcdefghijk0", strings.NewReader("file0"), 0) require.Nil(t, err) - _, err = cache.Write("abcdefghijk1", strings.NewReader("file1")) + _, err = cache.Write("abcdefghijk1", 
strings.NewReader("file1"), 0) require.Nil(t, err) - _, err = cache.Write("abcdefghijk2", strings.NewReader("file2")) + _, err = cache.Write("abcdefghijk2", strings.NewReader("file2"), 0) require.Nil(t, err) require.Equal(t, int64(15), cache.Size()) @@ -175,7 +175,7 @@ func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) { cache := newTestS3Store(t, mockServer, "my-bucket", "pfx", 10*1024) - _, err := cache.Write("abcdefghijk0", strings.NewReader("file0")) + _, err := cache.Write("abcdefghijk0", strings.NewReader("file0"), 0) require.Nil(t, err) // Set the ID provider to return empty (no valid IDs) diff --git a/s3/client.go b/s3/client.go index 83d8195e..29cad3a4 100644 --- a/s3/client.go +++ b/s3/client.go @@ -45,25 +45,36 @@ func New(config *Config) *Client { } // PutObject uploads body to the given key. The key is automatically prefixed with the client's -// configured prefix. The body size does not need to be known in advance. +// configured prefix. // -// If the entire body fits in a single part (5 MB), it is uploaded with a simple PUT request -// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html). Otherwise, the body -// is uploaded using S3 multipart upload, reading one part at a time into memory -// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html). -func (c *Client) PutObject(ctx context.Context, key string, body io.Reader) error { +// If untrustedLength is between 1 and 5 GB, the body is streamed directly to S3 via a +// single PUT request without buffering. The read is limited to untrustedLength bytes; +// any extra data in the body is ignored. If the body is shorter than claimed, the upload fails. +// +// Otherwise (untrustedLength <= 0 or > 5 GB), the first 5 MB are buffered to decide +// between a simple PUT and multipart upload. 
+// +// See https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html +// and https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html +func (c *Client) PutObject(ctx context.Context, key string, body io.Reader, untrustedLength int64) error { + if untrustedLength > 0 && untrustedLength <= maxSinglePutSize { + // Stream directly: Content-Length is known (but untrusted). LimitReader ensures we send at most + // untrustedLength bytes, and any extra data in body is ignored. + return c.putObject(ctx, key, io.LimitReader(body, untrustedLength), untrustedLength) + } + // Buffered path: read first 5 MB to decide simple vs multipart first := make([]byte, partSize) n, err := io.ReadFull(body, first) if errors.Is(err, io.ErrUnexpectedEOF) || err == io.EOF { - return c.putObjectSimple(ctx, key, bytes.NewReader(first[:n]), int64(n)) + return c.putObject(ctx, key, bytes.NewReader(first[:n]), int64(n)) } else if err != nil { return fmt.Errorf("error reading object %s from client: %w", key, err) } return c.putObjectMultipart(ctx, key, io.MultiReader(bytes.NewReader(first), body)) } -// putObjectSimple uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. -func (c *Client) putObjectSimple(ctx context.Context, key string, body io.Reader, size int64) error { +// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. 
+func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error { log.Tag(tagS3Client).Debug("Uploading object %s (%d bytes)", key, size) req, err := http.NewRequestWithContext(ctx, http.MethodPut, c.config.ObjectURL(key), body) if err != nil { diff --git a/s3/client_test.go b/s3/client_test.go index 84402831..652db3e7 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -419,7 +419,7 @@ func TestClient_PutGetObject(t *testing.T) { ctx := context.Background() // Put - err := client.PutObject(ctx, "test-key", strings.NewReader("hello world")) + err := client.PutObject(ctx, "test-key", strings.NewReader("hello world"), 0) require.Nil(t, err) // Get @@ -439,7 +439,7 @@ func TestClient_PutGetObject_WithPrefix(t *testing.T) { ctx := context.Background() - err := client.PutObject(ctx, "test-key", strings.NewReader("hello")) + err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 0) require.Nil(t, err) reader, _, err := client.GetObject(ctx, "test-key") @@ -471,7 +471,7 @@ func TestClient_DeleteObjects(t *testing.T) { // Put several objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data"))) + err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 0) require.Nil(t, err) } require.Equal(t, 5, mock.objectCount()) @@ -502,13 +502,13 @@ func TestClient_ListObjects(t *testing.T) { // Client with prefix "pfx": list should only return objects under pfx/ client := newTestClient(server, "my-bucket", "pfx") for i := 0; i < 3; i++ { - err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x"))) + err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } // Also put an object outside the prefix using a no-prefix client clientNoPrefix := newTestClient(server, "my-bucket", "") - err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y"))) + err := 
clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 0) require.Nil(t, err) // List with prefix client: should only see 3 @@ -532,7 +532,7 @@ func TestClient_ListObjects_Pagination(t *testing.T) { // Put 5 objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x"))) + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } @@ -564,7 +564,7 @@ func TestClient_ListAllObjects(t *testing.T) { ctx := context.Background() for i := 0; i < 10; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x"))) + err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } @@ -585,7 +585,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) { for i := range data { data[i] = byte(i % 256) } - err := client.PutObject(ctx, "large", bytes.NewReader(data)) + err := client.PutObject(ctx, "large", bytes.NewReader(data), 0) require.Nil(t, err) reader, size, err := client.GetObject(ctx, "large") @@ -609,7 +609,7 @@ func TestClient_PutObject_ChunkedUpload(t *testing.T) { for i := range data { data[i] = byte(i % 256) } - err := client.PutObject(ctx, "multipart", bytes.NewReader(data)) + err := client.PutObject(ctx, "multipart", bytes.NewReader(data), 0) require.Nil(t, err) reader, size, err := client.GetObject(ctx, "multipart") @@ -633,7 +633,7 @@ func TestClient_PutObject_ExactPartSize(t *testing.T) { for i := range data { data[i] = byte(i % 256) } - err := client.PutObject(ctx, "exact", bytes.NewReader(data)) + err := client.PutObject(ctx, "exact", bytes.NewReader(data), 0) require.Nil(t, err) reader, size, err := client.GetObject(ctx, "exact") @@ -645,6 +645,63 @@ func TestClient_PutObject_ExactPartSize(t *testing.T) { require.Equal(t, data, got) } +func TestClient_PutObject_StreamingExactLength(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + 
client := newTestClient(server, "my-bucket", "pfx") + + ctx := context.Background() + + // untrustedLength matches body exactly — streams directly via putObject + err := client.PutObject(ctx, "stream-exact", strings.NewReader("hello world"), 11) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "stream-exact") + require.Nil(t, err) + require.Equal(t, int64(11), size) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "hello world", string(got)) +} + +func TestClient_PutObject_StreamingBodyLongerThanClaimed(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "pfx") + + ctx := context.Background() + + // Body has 11 bytes, but we claim 5 — only first 5 bytes should be stored + err := client.PutObject(ctx, "stream-long", strings.NewReader("hello world"), 5) + require.Nil(t, err) + + reader, size, err := client.GetObject(ctx, "stream-long") + require.Nil(t, err) + require.Equal(t, int64(5), size) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "hello", string(got)) +} + +func TestClient_PutObject_StreamingBodyShorterThanClaimed(t *testing.T) { + server, _ := newMockS3Server() + defer server.Close() + client := newTestClient(server, "my-bucket", "pfx") + + ctx := context.Background() + + // Body has 5 bytes, but we claim 100 — should fail + err := client.PutObject(ctx, "stream-short", strings.NewReader("hello"), 100) + require.Error(t, err) + require.Contains(t, err.Error(), "ContentLength") + + // Object should not exist + _, _, err = client.GetObject(ctx, "stream-short") + require.Error(t, err) +} + func TestClient_PutObject_NestedKey(t *testing.T) { server, _ := newMockS3Server() defer server.Close() @@ -652,7 +709,7 @@ func TestClient_PutObject_NestedKey(t *testing.T) { ctx := context.Background() - err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested")) + err := 
client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 0) require.Nil(t, err) reader, _, err := client.GetObject(ctx, "deep/nested/prefix/file.txt") @@ -682,7 +739,7 @@ func TestClient_ListAllObjects_20k(t *testing.T) { for i := 0; i < batchSize; i++ { idx := batch*batchSize + i key := fmt.Sprintf("%08d", idx) - err := client.PutObject(ctx, key, bytes.NewReader([]byte("x"))) + err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } } @@ -780,7 +837,7 @@ func TestClient_RealBucket(t *testing.T) { content := "hello from ntfy s3 test" // Put - err := client.PutObject(ctx, key, strings.NewReader(content)) + err := client.PutObject(ctx, key, strings.NewReader(content), 0) require.Nil(t, err) // Get @@ -818,7 +875,7 @@ func TestClient_RealBucket(t *testing.T) { // Put 10 objects for i := 0; i < 10; i++ { - err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x")) + err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"), 0) require.Nil(t, err) } @@ -843,7 +900,7 @@ func TestClient_RealBucket(t *testing.T) { data[i] = byte(i % 256) } - err := client.PutObject(ctx, key, bytes.NewReader(data)) + err := client.PutObject(ctx, key, bytes.NewReader(data), 0) require.Nil(t, err) reader, size, err := client.GetObject(ctx, key) diff --git a/s3/util.go b/s3/util.go index 0bcc96d2..1f4c2dd9 100644 --- a/s3/util.go +++ b/s3/util.go @@ -28,6 +28,10 @@ const ( // part size of 5 MB for all parts except the last. partSize = 5 * 1024 * 1024 + // maxSinglePutSize is the maximum size for a single PUT upload (5 GB). + // Objects larger than this must use multipart upload. 
+ maxSinglePutSize = 5 * 1024 * 1024 * 1024 + // maxPages is the max number of pages to iterate through when listing objects maxPages = 500 ) diff --git a/server/server.go b/server/server.go index 99a61906..87eee5d6 100644 --- a/server/server.go +++ b/server/server.go @@ -1432,16 +1432,13 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Me if m.Time > attachmentExpiry { return errHTTPBadRequestAttachmentsExpiryBeforeDelivery.With(m) } - contentLengthStr := r.Header.Get("Content-Length") - if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below - contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64) - if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) { - return errHTTPEntityTooLargeAttachment.With(m).Fields(log.Context{ - "message_content_length": contentLength, - "attachment_total_size_remaining": vinfo.Stats.AttachmentTotalSizeRemaining, - "attachment_file_size_limit": vinfo.Limits.AttachmentFileSizeLimit, - }) - } + // Early "do-not-trust" check, hard limit see below + if r.ContentLength > 0 && (r.ContentLength > vinfo.Stats.AttachmentTotalSizeRemaining || r.ContentLength > vinfo.Limits.AttachmentFileSizeLimit) { + return errHTTPEntityTooLargeAttachment.With(m).Fields(log.Context{ + "message_content_length": r.ContentLength, + "attachment_total_size_remaining": vinfo.Stats.AttachmentTotalSizeRemaining, + "attachment_file_size_limit": vinfo.Limits.AttachmentFileSizeLimit, + }) } if m.Attachment == nil { m.Attachment = &model.Attachment{} @@ -1461,7 +1458,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Me util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit), util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining), } - m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...) + m.Attachment.Size, err = s.fileCache.Write(m.ID, body, r.ContentLength, limiters...) 
if errors.Is(err, util.ErrLimitReached) { return errHTTPEntityTooLargeAttachment.With(m) } else if err != nil { diff --git a/server/server_test.go b/server/server_test.go index cb20cbda..449b6006 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2218,8 +2218,8 @@ func TestServer_PublishAttachmentTooLargeContentLength(t *testing.T) { forEachBackend(t, func(t *testing.T, databaseURL string) { content := util.RandomString(5000) // > 4096 s := newTestServer(t, newTestConfig(t, databaseURL)) - response := request(t, s, "PUT", "/mytopic", content, map[string]string{ - "Content-Length": "20000000", + response := request(t, s, "PUT", "/mytopic", content, nil, func(r *http.Request) { + r.ContentLength = 20000000 }) err := toHTTPError(t, response.Body.String()) require.Equal(t, 413, response.Code) diff --git a/tools/s3cli/main.go b/tools/s3cli/main.go index 0e640823..5de8a75c 100644 --- a/tools/s3cli/main.go +++ b/tools/s3cli/main.go @@ -58,6 +58,7 @@ func cmdPut(ctx context.Context, client *s3.Client) { path := os.Args[3] var r io.Reader + var size int64 if path == "-" { r = os.Stdin } else { @@ -66,10 +67,15 @@ func cmdPut(ctx context.Context, client *s3.Client) { fail("open %s: %s", path, err) } defer f.Close() + stat, err := f.Stat() + if err != nil { + fail("stat %s: %s", path, err) + } r = f + size = stat.Size() } - if err := client.PutObject(ctx, key, r); err != nil { + if err := client.PutObject(ctx, key, r, size); err != nil { fail("put: %s", err) } fmt.Fprintf(os.Stderr, "uploaded %s\n", key) From ad501feab1d46f9cfb8460d22c746c0a977c8806 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 21 Mar 2026 21:59:59 -0400 Subject: [PATCH 18/32] Rewrite tests --- attachment/store_file_test.go | 168 +---------------------- attachment/store_s3_test.go | 195 +++----------------------- attachment/store_test.go | 252 ++++++++++++++++++++++++++++++++++ 3 files changed, 275 insertions(+), 340 deletions(-) create mode 100644 attachment/store_test.go diff --git 
a/attachment/store_file_test.go b/attachment/store_file_test.go index 998a2bea..d0b6e135 100644 --- a/attachment/store_file_test.go +++ b/attachment/store_file_test.go @@ -1,180 +1,16 @@ package attachment import ( - "bytes" - "fmt" - "io" - "os" - "strings" "testing" - "time" "github.com/stretchr/testify/require" - "heckel.io/ntfy/v2/util" ) -var ( - oneKilobyteArray = make([]byte, 1024) -) - -func TestFileStore_Write_Success(t *testing.T) { - dir, c := newTestFileStore(t) - size, err := c.Write("abcdefghijkl", strings.NewReader("normal file"), 0, util.NewFixedLimiter(999)) - require.Nil(t, err) - require.Equal(t, int64(11), size) - require.Equal(t, "normal file", readFile(t, dir+"/abcdefghijkl")) - require.Equal(t, int64(11), c.Size()) - require.Equal(t, int64(10229), c.Remaining()) -} - -func TestFileStore_Write_Read_Success(t *testing.T) { - _, c := newTestFileStore(t) - size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 0) - require.Nil(t, err) - require.Equal(t, int64(11), size) - - reader, readSize, err := c.Read("abcdefghijkl") - require.Nil(t, err) - require.Equal(t, int64(11), readSize) - defer reader.Close() - data, err := io.ReadAll(reader) - require.Nil(t, err) - require.Equal(t, "hello world", string(data)) -} - -func TestFileStore_Write_Remove_Success(t *testing.T) { - dir, c := newTestFileStore(t) // max = 10k (10240), each = 1k (1024) - for i := 0; i < 10; i++ { // 10x999 = 9990 - size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 999)), 0) - require.Nil(t, err) - require.Equal(t, int64(999), size) - } - require.Equal(t, int64(9990), c.Size()) - require.Equal(t, int64(250), c.Remaining()) - require.FileExists(t, dir+"/abcdefghijk1") - require.FileExists(t, dir+"/abcdefghijk5") - - require.Nil(t, c.Remove("abcdefghijk1", "abcdefghijk5")) - require.NoFileExists(t, dir+"/abcdefghijk1") - require.NoFileExists(t, dir+"/abcdefghijk5") - require.Equal(t, int64(8*999), c.Size()) - require.Equal(t, 
int64(10240-8*999), c.Remaining()) -} - -func TestFileStore_Write_FailedTotalSizeLimit(t *testing.T) { - dir, c := newTestFileStore(t) - for i := 0; i < 10; i++ { - size, err := c.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(oneKilobyteArray), 0) - require.Nil(t, err) - require.Equal(t, int64(1024), size) - } - _, err := c.Write("abcdefghijkX", bytes.NewReader(oneKilobyteArray), 0) - require.Equal(t, util.ErrLimitReached, err) - require.NoFileExists(t, dir+"/abcdefghijkX") -} - -func TestFileStore_Write_FailedAdditionalLimiter(t *testing.T) { - dir, c := newTestFileStore(t) - _, err := c.Write("abcdefghijkl", bytes.NewReader(make([]byte, 1001)), 0, util.NewFixedLimiter(1000)) - require.Equal(t, util.ErrLimitReached, err) - require.NoFileExists(t, dir+"/abcdefghijkl") -} - -func TestFileStore_Write_UntrustedContentLengthExact(t *testing.T) { - dir, c := newTestFileStore(t) - size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 11) - require.Nil(t, err) - require.Equal(t, int64(11), size) - require.Equal(t, "hello world", readFile(t, dir+"/abcdefghijkl")) -} - -func TestFileStore_Write_UntrustedContentLengthBodyLonger(t *testing.T) { - dir, c := newTestFileStore(t) - // Body has 11 bytes, but we claim 5 — only first 5 bytes should be stored - size, err := c.Write("abcdefghijkl", strings.NewReader("hello world"), 5) - require.Nil(t, err) - require.Equal(t, int64(5), size) - require.Equal(t, "hello", readFile(t, dir+"/abcdefghijkl")) -} - -func TestFileStore_Write_UntrustedContentLengthBodyShorter(t *testing.T) { - dir, c := newTestFileStore(t) - // Body has 5 bytes, but we claim 100 — should fail with content length mismatch - _, err := c.Write("abcdefghijkl", strings.NewReader("hello"), 100) - require.Error(t, err) - require.Contains(t, err.Error(), "content length mismatch") - require.NoFileExists(t, dir+"/abcdefghijkl") -} - -func TestFileStore_Read_NotFound(t *testing.T) { - _, c := newTestFileStore(t) - _, _, err := 
c.Read("abcdefghijkl") - require.Error(t, err) -} - -func TestFileStore_Sync(t *testing.T) { - dir, c := newTestFileStore(t) - - // Write some files - _, err := c.Write("abcdefghijk0", strings.NewReader("file0"), 0) - require.Nil(t, err) - _, err = c.Write("abcdefghijk1", strings.NewReader("file1"), 0) - require.Nil(t, err) - _, err = c.Write("abcdefghijk2", strings.NewReader("file2"), 0) - require.Nil(t, err) - - require.Equal(t, int64(15), c.Size()) - - // Set the ID provider to only know about file 0 and 2 - c.localIDs = func() ([]string, error) { - return []string{"abcdefghijk0", "abcdefghijk2"}, nil - } - - // Make file 1's mod time old enough to be cleaned up (> 1 hour) - oldTime := time.Unix(1, 0) - os.Chtimes(dir+"/abcdefghijk1", oldTime, oldTime) - - // Run sync - require.Nil(t, c.sync()) - - // File 1 should be deleted (orphan, old enough) - require.NoFileExists(t, dir+"/abcdefghijk1") - require.FileExists(t, dir+"/abcdefghijk0") - require.FileExists(t, dir+"/abcdefghijk2") - - // Size should be updated - require.Equal(t, int64(10), c.Size()) -} - -func TestFileStore_Sync_SkipsRecentFiles(t *testing.T) { - dir, c := newTestFileStore(t) - - // Write a file - _, err := c.Write("abcdefghijk0", strings.NewReader("file0"), 0) - require.Nil(t, err) - - // Set the ID provider to return empty (no valid IDs) - c.localIDs = func() ([]string, error) { - return []string{}, nil - } - - // File was just created, so it should NOT be deleted (< 1 hour old) - require.Nil(t, c.sync()) - require.FileExists(t, dir+"/abcdefghijk0") -} - -func newTestFileStore(t *testing.T) (dir string, cache *Store) { +func newTestFileStore(t *testing.T, totalSizeLimit int64) (dir string, cache *Store) { t.Helper() dir = t.TempDir() - cache, err := NewFileStore(dir, 10*1024, nil) + cache, err := NewFileStore(dir, totalSizeLimit, nil) require.Nil(t, err) t.Cleanup(func() { cache.Close() }) return dir, cache } - -func readFile(t *testing.T, f string) string { - t.Helper() - b, err := 
os.ReadFile(f) - require.Nil(t, err) - return string(b) -} diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 37bd0ecb..2d4635ff 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -1,7 +1,6 @@ package attachment import ( - "bytes" "encoding/xml" "fmt" "io" @@ -14,43 +13,12 @@ import ( "github.com/stretchr/testify/require" "heckel.io/ntfy/v2/s3" - "heckel.io/ntfy/v2/util" ) -// --- Integration tests using a mock S3 server --- - -func TestS3Store_WriteReadRemove(t *testing.T) { - server := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - - // Write - size, err := cache.Write("abcdefghijkl", strings.NewReader("hello world"), 0) - require.Nil(t, err) - require.Equal(t, int64(11), size) - require.Equal(t, int64(11), cache.Size()) - - // Read back - reader, readSize, err := cache.Read("abcdefghijkl") - require.Nil(t, err) - require.Equal(t, int64(11), readSize) - data, err := io.ReadAll(reader) - reader.Close() - require.Nil(t, err) - require.Equal(t, "hello world", string(data)) - - // Remove - require.Nil(t, cache.Remove("abcdefghijkl")) - require.Equal(t, int64(0), cache.Size()) - - // Read after remove should fail - _, _, err = cache.Read("abcdefghijkl") - require.Error(t, err) -} +// --- S3-specific tests --- func TestS3Store_WriteNoPrefix(t *testing.T) { - server := newMockS3Server() + server, _ := newMockS3Server() defer server.Close() cache := newTestS3Store(t, server, "my-bucket", "", 10*1024) @@ -67,131 +35,6 @@ func TestS3Store_WriteNoPrefix(t *testing.T) { require.Equal(t, "test", string(data)) } -func TestS3Store_WriteTotalSizeLimit(t *testing.T) { - server := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "pfx", 100) - - // First write fits - _, err := cache.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)), 0) - require.Nil(t, err) - require.Equal(t, int64(80), cache.Size()) - 
require.Equal(t, int64(20), cache.Remaining()) - - // Second write exceeds total limit - _, err = cache.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)), 0) - require.ErrorIs(t, err, util.ErrLimitReached) -} - -func TestS3Store_WriteFileSizeLimit(t *testing.T) { - server := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - - _, err := cache.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), 0, util.NewFixedLimiter(100)) - require.ErrorIs(t, err, util.ErrLimitReached) -} - -func TestS3Store_WriteRemoveMultiple(t *testing.T) { - server := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - - for i := 0; i < 5; i++ { - _, err := cache.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)), 0) - require.Nil(t, err) - } - require.Equal(t, int64(500), cache.Size()) - - require.Nil(t, cache.Remove("abcdefghijk1", "abcdefghijk3")) - require.Equal(t, int64(300), cache.Size()) -} - -func TestS3Store_ReadNotFound(t *testing.T) { - server := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - - _, _, err := cache.Read("abcdefghijkl") - require.Error(t, err) -} - -func TestS3Store_InvalidID(t *testing.T) { - server := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - - _, err := cache.Write("bad", strings.NewReader("x"), 0) - require.Equal(t, errInvalidFileID, err) - - _, _, err = cache.Read("bad") - require.Equal(t, errInvalidFileID, err) - - err = cache.Remove("bad") - require.Equal(t, errInvalidFileID, err) -} - -func TestS3Store_Sync(t *testing.T) { - server := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "pfx", 10*1024) - - // Write some files - _, err := cache.Write("abcdefghijk0", strings.NewReader("file0"), 0) - require.Nil(t, err) - _, err = 
cache.Write("abcdefghijk1", strings.NewReader("file1"), 0) - require.Nil(t, err) - _, err = cache.Write("abcdefghijk2", strings.NewReader("file2"), 0) - require.Nil(t, err) - - require.Equal(t, int64(15), cache.Size()) - - // Set the ID provider to only know about file 0 and 2 - // All mock objects have LastModified set to 2 hours ago, so orphans are eligible for deletion - cache.localIDs = func() ([]string, error) { - return []string{"abcdefghijk0", "abcdefghijk2"}, nil - } - - // Run sync - require.Nil(t, cache.sync()) - - // File 1 should be deleted (orphan) - _, _, err = cache.Read("abcdefghijk1") - require.Error(t, err) - - // Size should be updated - require.Equal(t, int64(10), cache.Size()) -} - -func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) { - mockServer := newMockS3ServerWithModTime(time.Now()) - defer mockServer.Close() - - cache := newTestS3Store(t, mockServer, "my-bucket", "pfx", 10*1024) - - _, err := cache.Write("abcdefghijk0", strings.NewReader("file0"), 0) - require.Nil(t, err) - - // Set the ID provider to return empty (no valid IDs) - cache.localIDs = func() ([]string, error) { - return []string{}, nil - } - - // File was "just created" (mock returns recent time), so it should NOT be deleted - require.Nil(t, cache.sync()) - - // File should still exist - reader, _, err := cache.Read("abcdefghijk0") - require.Nil(t, err) - reader.Close() -} - // --- Helpers --- func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store { @@ -219,24 +62,26 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string // ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. 
type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body - uploads map[string]map[int][]byte // uploadID -> partNumber -> data - nextID int // counter for generating upload IDs - lastModTime time.Time // time to return for LastModified in list responses - mu sync.RWMutex + objects map[string][]byte // full key (bucket/key) -> body + modTimes map[string]time.Time // full key (bucket/key) -> last modified time + uploads map[string]map[int][]byte // uploadID -> partNumber -> data + nextID int // counter for generating upload IDs + mu sync.RWMutex } -func newMockS3Server() *httptest.Server { - return newMockS3ServerWithModTime(time.Now().Add(-2 * time.Hour)) -} - -func newMockS3ServerWithModTime(modTime time.Time) *httptest.Server { +func newMockS3Server() (*httptest.Server, *mockS3Server) { m := &mockS3Server{ - objects: make(map[string][]byte), - uploads: make(map[string]map[int][]byte), - lastModTime: modTime, + objects: make(map[string][]byte), + modTimes: make(map[string]time.Time), + uploads: make(map[string]map[int][]byte), } - return httptest.NewTLSServer(m) + return httptest.NewTLSServer(m), m +} + +func (m *mockS3Server) setModTime(path string, t time.Time) { + m.mu.Lock() + m.modTimes[path] = t + m.mu.Unlock() } func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -274,6 +119,7 @@ func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path st } m.mu.Lock() m.objects[path] = body + m.modTimes[path] = time.Now() m.mu.Unlock() w.WriteHeader(http.StatusOK) } @@ -333,6 +179,7 @@ func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Re assembled = append(assembled, parts[i]...) 
} m.objects[path] = assembled + m.modTimes[path] = time.Now() delete(m.uploads, uploadID) m.mu.Unlock() @@ -402,7 +249,7 @@ func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucket contents = append(contents, s3ListObject{ Key: objKey, Size: int64(len(body)), - LastModified: m.lastModTime.Format(time.RFC3339), + LastModified: m.modTimes[key].Format(time.RFC3339), }) } } diff --git a/attachment/store_test.go b/attachment/store_test.go new file mode 100644 index 00000000..7b5a6013 --- /dev/null +++ b/attachment/store_test.go @@ -0,0 +1,252 @@ +package attachment + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "heckel.io/ntfy/v2/util" +) + +const testSizeLimit = 10 * 1024 + +func TestStore_WriteReadRemove(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + // Write + size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"), 0) + require.Nil(t, err) + require.Equal(t, int64(11), size) + require.Equal(t, int64(11), s.Size()) + + // Read back + reader, readSize, err := s.Read("abcdefghijkl") + require.Nil(t, err) + require.Equal(t, int64(11), readSize) + data, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "hello world", string(data)) + + // Remove + require.Nil(t, s.Remove("abcdefghijkl")) + require.Equal(t, int64(0), s.Size()) + + // Read after remove should fail + _, _, err = s.Read("abcdefghijkl") + require.Error(t, err) + }) +} + +func TestStore_WriteRemoveMultiple(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + for i := 0; i < 5; i++ { + _, err := s.Write(fmt.Sprintf("abcdefghijk%d", i), bytes.NewReader(make([]byte, 100)), 0) + require.Nil(t, err) + } + require.Equal(t, int64(500), s.Size()) + + require.Nil(t, s.Remove("abcdefghijk1", "abcdefghijk3")) + require.Equal(t, int64(300), s.Size()) + + // 
Removed files should not be readable + _, _, err := s.Read("abcdefghijk1") + require.Error(t, err) + _, _, err = s.Read("abcdefghijk3") + require.Error(t, err) + + // Remaining files should still be readable + for _, id := range []string{"abcdefghijk0", "abcdefghijk2", "abcdefghijk4"} { + reader, _, err := s.Read(id) + require.Nil(t, err) + reader.Close() + } + }) +} + +func TestStore_WriteTotalSizeLimit(t *testing.T) { + forEachBackend(t, 100, func(t *testing.T, s *Store, _ func(string)) { + // First write fits + _, err := s.Write("abcdefghijk0", bytes.NewReader(make([]byte, 80)), 0) + require.Nil(t, err) + require.Equal(t, int64(80), s.Size()) + require.Equal(t, int64(20), s.Remaining()) + + // Second write exceeds total limit + _, err = s.Write("abcdefghijk1", bytes.NewReader(make([]byte, 50)), 0) + require.ErrorIs(t, err, util.ErrLimitReached) + }) +} + +func TestStore_WriteAdditionalLimiter(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + _, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), 0, util.NewFixedLimiter(100)) + require.ErrorIs(t, err, util.ErrLimitReached) + + // File should not be readable (was cleaned up) + _, _, err = s.Read("abcdefghijkl") + require.Error(t, err) + }) +} + +func TestStore_WriteWithLimiter(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + size, err := s.Write("abcdefghijkl", strings.NewReader("normal file"), 0, util.NewFixedLimiter(999)) + require.Nil(t, err) + require.Equal(t, int64(11), size) + require.Equal(t, int64(11), s.Size()) + }) +} + +func TestStore_ReadNotFound(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + _, _, err := s.Read("abcdefghijkl") + require.Error(t, err) + }) +} + +func TestStore_InvalidID(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + _, err := s.Write("bad", strings.NewReader("x"), 0) 
+ require.Equal(t, errInvalidFileID, err) + + _, _, err = s.Read("bad") + require.Equal(t, errInvalidFileID, err) + + err = s.Remove("bad") + require.Equal(t, errInvalidFileID, err) + }) +} + +func TestStore_WriteUntrustedLengthExact(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"), 11) + require.Nil(t, err) + require.Equal(t, int64(11), size) + + reader, _, err := s.Read("abcdefghijkl") + require.Nil(t, err) + data, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "hello world", string(data)) + }) +} + +func TestStore_WriteUntrustedLengthBodyLonger(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + // Body has 11 bytes, but we claim 5 — only first 5 bytes should be stored + size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"), 5) + require.Nil(t, err) + require.Equal(t, int64(5), size) + + reader, _, err := s.Read("abcdefghijkl") + require.Nil(t, err) + data, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, "hello", string(data)) + }) +} + +func TestStore_WriteUntrustedLengthBodyShorter(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + // Body has 5 bytes, but we claim 100 — should fail + _, err := s.Write("abcdefghijkl", strings.NewReader("hello"), 100) + require.Error(t, err) + + // File should not be readable (was cleaned up) + _, _, err = s.Read("abcdefghijkl") + require.Error(t, err) + }) +} + +func TestStore_Sync(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, makeOld func(string)) { + // Write some files + _, err := s.Write("abcdefghijk0", strings.NewReader("file0"), 0) + require.Nil(t, err) + _, err = s.Write("abcdefghijk1", strings.NewReader("file1"), 0) + require.Nil(t, err) + _, err = s.Write("abcdefghijk2", 
strings.NewReader("file2"), 0) + require.Nil(t, err) + + require.Equal(t, int64(15), s.Size()) + + // Set the ID provider to only know about file 0 and 2 + s.localIDs = func() ([]string, error) { + return []string{"abcdefghijk0", "abcdefghijk2"}, nil + } + + // Make file 1 old enough to be cleaned up + makeOld("abcdefghijk1") + + // Run sync + require.Nil(t, s.sync()) + + // File 1 should be deleted (orphan, old enough) + _, _, err = s.Read("abcdefghijk1") + require.Error(t, err) + + // Files 0 and 2 should still be readable + r, _, err := s.Read("abcdefghijk0") + require.Nil(t, err) + r.Close() + r, _, err = s.Read("abcdefghijk2") + require.Nil(t, err) + r.Close() + + // Size should be updated + require.Equal(t, int64(10), s.Size()) + }) +} + +func TestStore_Sync_SkipsRecentFiles(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + // Write a file + _, err := s.Write("abcdefghijk0", strings.NewReader("file0"), 0) + require.Nil(t, err) + + // Set the ID provider to return empty (no valid IDs) + s.localIDs = func() ([]string, error) { + return []string{}, nil + } + + // File was just created, so it should NOT be deleted (< 1 hour old) + require.Nil(t, s.sync()) + + // File should still exist + reader, _, err := s.Read("abcdefghijk0") + require.Nil(t, err) + reader.Close() + }) +} + +// forEachBackend runs f against both the file and S3 backends. It also provides a makeOld +// callback that makes a specific object's timestamp old enough for orphan cleanup (> 1 hour). +// For the file backend, this uses os.Chtimes; for the S3 backend, it sets the object's +// LastModified time in the mock server. Objects start with recent timestamps by default. 
+func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *Store, makeOld func(string))) { + t.Run("file", func(t *testing.T) { + dir, s := newTestFileStore(t, totalSizeLimit) + makeOld := func(id string) { + oldTime := time.Unix(1, 0) + os.Chtimes(filepath.Join(dir, id), oldTime, oldTime) + } + f(t, s, makeOld) + }) + t.Run("s3", func(t *testing.T) { + server, mock := newMockS3Server() + defer server.Close() + s := newTestS3Store(t, server, "my-bucket", "pfx", totalSizeLimit) + makeOld := func(id string) { + mock.setModTime("my-bucket/pfx/"+id, time.Unix(1, 0)) + } + f(t, s, makeOld) + }) +} From f2d4575831a3f48014c56576e81686ce19312d79 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 08:15:23 -0400 Subject: [PATCH 19/32] Use real S3 for tests --- attachment/store_s3_test.go | 299 ++++++++---------------------------- attachment/store_test.go | 11 +- s3/client.go | 2 +- 3 files changed, 71 insertions(+), 241 deletions(-) diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 2d4635ff..a41c6f8b 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -1,11 +1,9 @@ package attachment import ( - "encoding/xml" - "fmt" + "context" "io" - "net/http" - "net/http/httptest" + "os" "strings" "sync" "testing" @@ -15,13 +13,23 @@ import ( "heckel.io/ntfy/v2/s3" ) -// --- S3-specific tests --- - -func TestS3Store_WriteNoPrefix(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - - cache := newTestS3Store(t, server, "my-bucket", "", 10*1024) +func TestS3Store_WriteWithPrefix(t *testing.T) { + s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL") + if s3URL == "" { + t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set") + } + cfg, err := s3.ParseURL(s3URL) + require.Nil(t, err) + cfg.Prefix = "test-prefix" + client := s3.New(cfg) + deleteAllObjects(client) + backend := newS3Backend(client) + cache, err := newStore(backend, 10*1024, nil) + require.Nil(t, err) + t.Cleanup(func() { + 
deleteAllObjects(client) + cache.Close() + }) size, err := cache.Write("abcdefghijkl", strings.NewReader("test"), 0) require.Nil(t, err) @@ -37,241 +45,64 @@ func TestS3Store_WriteNoPrefix(t *testing.T) { // --- Helpers --- -func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store { +func newTestRealS3Store(t *testing.T, totalSizeLimit int64) (*Store, *modTimeOverrideBackend) { t.Helper() - host := strings.TrimPrefix(server.URL, "https://") - backend := newS3Backend(s3.New(&s3.Config{ - AccessKey: "AKID", - SecretKey: "SECRET", - Region: "us-east-1", - Endpoint: host, - Bucket: bucket, - Prefix: prefix, - PathStyle: true, - HTTPClient: server.Client(), - })) - cache, err := newStore(backend, totalSizeLimit, nil) + s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL") + if s3URL == "" { + t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set") + } + cfg, err := s3.ParseURL(s3URL) require.Nil(t, err) - t.Cleanup(func() { cache.Close() }) - return cache + client := s3.New(cfg) + inner := newS3Backend(client) + wrapper := &modTimeOverrideBackend{backend: inner, modTimes: make(map[string]time.Time)} + deleteAllObjects(client) + store, err := newStore(wrapper, totalSizeLimit, nil) + require.Nil(t, err) + t.Cleanup(func() { + deleteAllObjects(client) + store.Close() + }) + return store, wrapper } -// --- Mock S3 server --- -// -// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and -// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. 
- -type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body - modTimes map[string]time.Time // full key (bucket/key) -> last modified time - uploads map[string]map[int][]byte // uploadID -> partNumber -> data - nextID int // counter for generating upload IDs - mu sync.RWMutex -} - -func newMockS3Server() (*httptest.Server, *mockS3Server) { - m := &mockS3Server{ - objects: make(map[string][]byte), - modTimes: make(map[string]time.Time), - uploads: make(map[string]map[int][]byte), +func deleteAllObjects(client *s3.Client) { + objects, _ := client.ListObjectsV2(context.Background()) + keys := make([]string, 0, len(objects)) + for _, obj := range objects { + keys = append(keys, obj.Key) } - return httptest.NewTLSServer(m), m -} - -func (m *mockS3Server) setModTime(path string, t time.Time) { - m.mu.Lock() - m.modTimes[path] = t - m.mu.Unlock() -} - -func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Path is /{bucket}[/{key...}] - path := strings.TrimPrefix(r.URL.Path, "/") - q := r.URL.Query() - - switch { - case r.Method == http.MethodPut && q.Has("partNumber"): - m.handleUploadPart(w, r, path) - case r.Method == http.MethodPut: - m.handlePut(w, r, path) - case r.Method == http.MethodPost && q.Has("uploads"): - m.handleInitiateMultipart(w, r, path) - case r.Method == http.MethodPost && q.Has("uploadId"): - m.handleCompleteMultipart(w, r, path) - case r.Method == http.MethodDelete && q.Has("uploadId"): - m.handleAbortMultipart(w, r, path) - case r.Method == http.MethodGet && q.Get("list-type") == "2": - m.handleList(w, r, path) - case r.Method == http.MethodGet: - m.handleGet(w, r, path) - case r.Method == http.MethodPost && q.Has("delete"): - m.handleDelete(w, r, path) - default: - http.Error(w, "not implemented", http.StatusNotImplemented) + if len(keys) > 0 { + client.DeleteObjects(context.Background(), keys) //nolint:errcheck } } -func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, 
path string) { - body, err := io.ReadAll(r.Body) +// modTimeOverrideBackend wraps a backend and allows overriding LastModified times returned by List(). +// This is used in tests to simulate old objects on backends (like real S3) where +// LastModified cannot be set directly. +type modTimeOverrideBackend struct { + backend + mu sync.Mutex + modTimes map[string]time.Time // object ID -> override time +} + +func (b *modTimeOverrideBackend) List() ([]object, error) { + objects, err := b.backend.List() if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + return nil, err } - m.mu.Lock() - m.objects[path] = body - m.modTimes[path] = time.Now() - m.mu.Unlock() - w.WriteHeader(http.StatusOK) -} - -func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) { - m.mu.Lock() - m.nextID++ - uploadID := fmt.Sprintf("upload-%d", m.nextID) - m.uploads[uploadID] = make(map[int][]byte) - m.mu.Unlock() - - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `%s`, uploadID) -} - -func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - var partNumber int - fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber) - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - m.mu.Lock() - parts, ok := m.uploads[uploadID] - if !ok { - m.mu.Unlock() - http.Error(w, "NoSuchUpload", http.StatusNotFound) - return - } - parts[partNumber] = body - m.mu.Unlock() - - etag := fmt.Sprintf(`"etag-part-%d"`, partNumber) - w.Header().Set("ETag", etag) - w.WriteHeader(http.StatusOK) -} - -func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - - m.mu.Lock() - parts, ok := m.uploads[uploadID] - if !ok { - m.mu.Unlock() - http.Error(w, 
"NoSuchUpload", http.StatusNotFound) - return - } - - // Assemble parts in order - var assembled []byte - for i := 1; i <= len(parts); i++ { - assembled = append(assembled, parts[i]...) - } - m.objects[path] = assembled - m.modTimes[path] = time.Now() - delete(m.uploads, uploadID) - m.mu.Unlock() - - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `%s`, path) -} - -func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - m.mu.Lock() - delete(m.uploads, uploadID) - m.mu.Unlock() - w.WriteHeader(http.StatusNoContent) -} - -func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { - m.mu.RLock() - body, ok := m.objects[path] - m.mu.RUnlock() - if !ok { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte(`NoSuchKeyThe specified key does not exist.`)) - return - } - w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) - w.WriteHeader(http.StatusOK) - w.Write(body) -} - -func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) { - // bucketPath is just the bucket name - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - var req struct { - Objects []struct { - Key string `xml:"Key"` - } `xml:"Object"` - } - if err := xml.Unmarshal(body, &req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - m.mu.Lock() - for _, obj := range req.Objects { - delete(m.objects, bucketPath+"/"+obj.Key) - } - m.mu.Unlock() - w.WriteHeader(http.StatusOK) - w.Write([]byte(``)) -} - -func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) { - prefix := r.URL.Query().Get("prefix") - m.mu.RLock() - var contents []s3ListObject - for key, body := range m.objects { - // key is "bucket/objectkey", strip bucket prefix - objKey := 
strings.TrimPrefix(key, bucketPath+"/") - if objKey == key { - continue // different bucket - } - if prefix == "" || strings.HasPrefix(objKey, prefix) { - contents = append(contents, s3ListObject{ - Key: objKey, - Size: int64(len(body)), - LastModified: m.modTimes[key].Format(time.RFC3339), - }) + b.mu.Lock() + defer b.mu.Unlock() + for i, obj := range objects { + if t, ok := b.modTimes[obj.ID]; ok { + objects[i].LastModified = t } } - m.mu.RUnlock() - - resp := s3ListResponse{ - Contents: contents, - IsTruncated: false, - } - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(http.StatusOK) - xml.NewEncoder(w).Encode(resp) + return objects, nil } -type s3ListResponse struct { - XMLName xml.Name `xml:"ListBucketResult"` - Contents []s3ListObject `xml:"Contents"` - IsTruncated bool `xml:"IsTruncated"` -} - -type s3ListObject struct { - Key string `xml:"Key"` - Size int64 `xml:"Size"` - LastModified string `xml:"LastModified"` +func (b *modTimeOverrideBackend) setModTime(id string, t time.Time) { + b.mu.Lock() + b.modTimes[id] = t + b.mu.Unlock() } diff --git a/attachment/store_test.go b/attachment/store_test.go index 7b5a6013..645a2159 100644 --- a/attachment/store_test.go +++ b/attachment/store_test.go @@ -229,8 +229,9 @@ func TestStore_Sync_SkipsRecentFiles(t *testing.T) { // forEachBackend runs f against both the file and S3 backends. It also provides a makeOld // callback that makes a specific object's timestamp old enough for orphan cleanup (> 1 hour). -// For the file backend, this uses os.Chtimes; for the S3 backend, it sets the object's -// LastModified time in the mock server. Objects start with recent timestamps by default. +// For the file backend, this uses os.Chtimes; for the S3 backend, it overrides the object's +// LastModified time via a modTimeOverrideBackend wrapper. Objects start with recent timestamps +// by default. The S3 subtest is skipped if NTFY_TEST_ATTACHMENT_S3_URL is not set. 
func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *Store, makeOld func(string))) { t.Run("file", func(t *testing.T) { dir, s := newTestFileStore(t, totalSizeLimit) @@ -241,11 +242,9 @@ func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s * f(t, s, makeOld) }) t.Run("s3", func(t *testing.T) { - server, mock := newMockS3Server() - defer server.Close() - s := newTestS3Store(t, server, "my-bucket", "pfx", totalSizeLimit) + s, wrapper := newTestRealS3Store(t, totalSizeLimit) makeOld := func(id string) { - mock.setModTime("my-bucket/pfx/"+id, time.Unix(1, 0)) + wrapper.setModTime(id, time.Unix(1, 0)) } f(t, s, makeOld) }) diff --git a/s3/client.go b/s3/client.go index 29cad3a4..d9ec1ab8 100644 --- a/s3/client.go +++ b/s3/client.go @@ -86,7 +86,7 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size if err != nil { return fmt.Errorf("uploading object %s failed: %w", key, err) } - resp.Body.Close() + defer resp.Body.Close() if !isHTTPSuccess(resp) { return parseError(resp) } From f76135c5ab21e6f739a12697326292f8bc07efce Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 08:23:46 -0400 Subject: [PATCH 20/32] Large objects test --- attachment/store_test.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/attachment/store_test.go b/attachment/store_test.go index 645a2159..ec658114 100644 --- a/attachment/store_test.go +++ b/attachment/store_test.go @@ -123,6 +123,37 @@ func TestStore_InvalidID(t *testing.T) { }) } +func TestStore_WriteLargeObjects(t *testing.T) { + sizes := map[string]int64{ + "100B": 100, + "6MB": 6 * 1024 * 1024, + "12MB": 12 * 1024 * 1024, + } + for name, sz := range sizes { + t.Run(name, func(t *testing.T) { + forEachBackend(t, sz+1024, func(t *testing.T, s *Store, _ func(string)) { + data := make([]byte, sz) + for i := range data { + data[i] = byte(i % 251) + } + + size, err := s.Write("abcdefghijkl", 
bytes.NewReader(data), 0) + require.Nil(t, err) + require.Equal(t, sz, size) + require.Equal(t, sz, s.Size()) + + reader, readSize, err := s.Read("abcdefghijkl") + require.Nil(t, err) + require.Equal(t, sz, readSize) + got, err := io.ReadAll(reader) + reader.Close() + require.Nil(t, err) + require.Equal(t, data, got) + }) + }) + } +} + func TestStore_WriteUntrustedLengthExact(t *testing.T) { forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { size, err := s.Write("abcdefghijkl", strings.NewReader("hello world"), 11) From fa33d63138aa5dc230b3ef84e21fe5d63407d15d Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 08:26:18 -0400 Subject: [PATCH 21/32] More tests --- attachment/store_test.go | 70 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/attachment/store_test.go b/attachment/store_test.go index ec658114..7ac7cddb 100644 --- a/attachment/store_test.go +++ b/attachment/store_test.go @@ -103,6 +103,76 @@ func TestStore_WriteWithLimiter(t *testing.T) { }) } +func TestStore_WriteOverwriteSameID(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + // Write 100 bytes + _, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 100)), 0) + require.Nil(t, err) + require.Equal(t, int64(100), s.Size()) + + // Overwrite with 50 bytes + _, err = s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 50)), 0) + require.Nil(t, err) + require.Equal(t, int64(150), s.Size()) // Store tracks both writes + + // Read back should return the latest content + reader, readSize, err := s.Read("abcdefghijkl") + require.Nil(t, err) + require.Equal(t, int64(50), readSize) + reader.Close() + }) +} + +func TestStore_WriteAfterFailure(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { + // Failed write: limiter rejects it + _, err := s.Write("abcdefghijkl", bytes.NewReader(make([]byte, 200)), 0, 
util.NewFixedLimiter(100)) + require.ErrorIs(t, err, util.ErrLimitReached) + require.Equal(t, int64(0), s.Size()) + + // Subsequent write with a different ID should succeed + size, err := s.Write("abcdefghijk2", strings.NewReader("hello"), 0) + require.Nil(t, err) + require.Equal(t, int64(5), size) + require.Equal(t, int64(5), s.Size()) + + // The failed ID should not be readable + _, _, err = s.Read("abcdefghijkl") + require.Error(t, err) + + // The successful ID should be readable + reader, _, err := s.Read("abcdefghijk2") + require.Nil(t, err) + reader.Close() + }) +} + +func TestStore_SyncRecomputesSize(t *testing.T) { + forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, makeOld func(string)) { + // Write two files + _, err := s.Write("abcdefghijk0", bytes.NewReader(make([]byte, 100)), 0) + require.Nil(t, err) + _, err = s.Write("abcdefghijk1", bytes.NewReader(make([]byte, 200)), 0) + require.Nil(t, err) + require.Equal(t, int64(300), s.Size()) + + // Corrupt the in-memory size tracking + s.mu.Lock() + s.size = 999 + s.mu.Unlock() + require.Equal(t, int64(999), s.Size()) + + // Set localIDs to include both files so nothing gets deleted + s.localIDs = func() ([]string, error) { + return []string{"abcdefghijk0", "abcdefghijk1"}, nil + } + + // Sync should recompute size from the backend + require.Nil(t, s.sync()) + require.Equal(t, int64(300), s.Size()) + }) +} + func TestStore_ReadNotFound(t *testing.T) { forEachBackend(t, testSizeLimit, func(t *testing.T, s *Store, _ func(string)) { _, _, err := s.Read("abcdefghijkl") From 56b63c475ceffb9ba40364c78f0c9c165856bc7b Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 08:28:18 -0400 Subject: [PATCH 22/32] Remove MinIO mention --- docs/config.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/config.md b/docs/config.md index 853d4891..ae7547b3 100644 --- a/docs/config.md +++ b/docs/config.md @@ -525,7 +525,7 @@ Here's an example config using the local 
filesystem for attachment storage: ### S3 storage As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3, -MinIO, DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage. +DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage. To use S3, set `attachment-cache-dir` to an S3 URL with the following format: @@ -533,7 +533,7 @@ To use S3, set `attachment-cache-dir` to an S3 URL with the following format: s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT] ``` -When `endpoint` is specified, path-style addressing is enabled automatically (useful for MinIO and other S3-compatible stores). +When `endpoint` is specified, path-style addressing is enabled automatically (useful for S3-compatible stores like DigitalOcean Spaces). === "/etc/ntfy/server.yml (AWS S3)" ``` yaml @@ -541,7 +541,7 @@ When `endpoint` is specified, path-style addressing is enabled automatically (us attachment-cache-dir: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1" ``` -=== "/etc/ntfy/server.yml (MinIO/custom endpoint)" +=== "/etc/ntfy/server.yml (custom endpoint)" ``` yaml base-url: "https://ntfy.sh" attachment-cache-dir: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com" From 536c6f58072504a5c38ed04339785dd0ba8be609 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 08:38:41 -0400 Subject: [PATCH 23/32] More consistent logging --- attachment/backend_file.go | 8 +- attachment/backend_s3.go | 12 +- attachment/store.go | 3 + go.mod | 2 +- go.sum | 4 +- server/server.go | 28 ++--- server/server_manager.go | 12 +- web/package-lock.json | 224 ++++++++++++++++++------------------- 8 files changed, 142 insertions(+), 151 deletions(-) diff --git a/attachment/backend_file.go b/attachment/backend_file.go index 8726ddf4..e86ff1ec 100644 --- 
a/attachment/backend_file.go +++ b/attachment/backend_file.go @@ -6,12 +6,8 @@ import ( "os" "path/filepath" "time" - - "heckel.io/ntfy/v2/log" ) -const tagFileBackend = "attachment_file" - type fileBackend struct { dir string } @@ -86,8 +82,8 @@ func (b *fileBackend) Get(id string) (io.ReadCloser, int64, error) { func (b *fileBackend) Delete(ids ...string) error { for _, id := range ids { file := filepath.Join(b.dir, id) - if err := os.Remove(file); err != nil { - log.Tag(tagFileBackend).Field("message_id", id).Err(err).Debug("Error deleting attachment") + if err := os.Remove(file); err != nil && !os.IsNotExist(err) { + return err } } return nil diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index 081d6002..44f946f6 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -5,14 +5,10 @@ import ( "io" "time" - "heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/s3" ) -const ( - tagS3Backend = "attachment_s3" - deleteBatchSize = 1000 -) +const deleteBatchSize = 1000 type s3Backend struct { client *s3.Client @@ -55,11 +51,7 @@ func (b *s3Backend) Delete(ids ...string) error { if end > len(ids) { end = len(ids) } - batch := ids[i:end] - for _, id := range batch { - log.Tag(tagS3Backend).Field("message_id", id).Debug("Deleting attachment from S3") - } - if err := b.client.DeleteObjects(context.Background(), batch); err != nil { + if err := b.client.DeleteObjects(context.Background(), ids[i:end]); err != nil { return err } } diff --git a/attachment/store.go b/attachment/store.go index 6e7cfb99..7d6b0620 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -113,6 +113,9 @@ func (c *Store) Remove(ids ...string) error { } } // Remove from backend + for _, id := range ids { + log.Tag(tagStore).Field("message_id", id).Debug("Removing attachment") + } if err := c.backend.Delete(ids...); err != nil { return err } diff --git a/go.mod b/go.mod index 4f0451b6..7edd3710 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require 
github.com/pkg/errors v0.9.1 // indirect require ( firebase.google.com/go/v4 v4.19.0 github.com/SherClockHolmes/webpush-go v1.4.0 - github.com/jackc/pgx/v5 v5.8.0 + github.com/jackc/pgx/v5 v5.9.0 github.com/microcosm-cc/bluemonday v1.0.27 github.com/prometheus/client_golang v1.23.2 github.com/stripe/stripe-go/v74 v74.30.0 diff --git a/go.sum b/go.sum index 0851929d..97738e0f 100644 --- a/go.sum +++ b/go.sum @@ -108,8 +108,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= -github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/pgx/v5 v5.9.0 h1:T/dI+2TvmI2H8s/KH1/lXIbz1CUFk3gn5oTjr0/mBsE= +github.com/jackc/pgx/v5 v5.9.0/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= diff --git a/server/server.go b/server/server.go index 87eee5d6..77e7b0c0 100644 --- a/server/server.go +++ b/server/server.go @@ -65,7 +65,7 @@ type Server struct { userManager *user.Manager // Might be nil! 
messageCache *message.Cache // Database that stores the messages webPush *webpush.Store // Database that stores web push subscriptions - fileCache *attachment.Store // Attachment store (file system or S3) + attachment *attachment.Store // Attachment store (file system or S3) stripe stripeAPI // Stripe API, can be replaced with a mock priceCache *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!) metricsHandler http.Handler // Handles /metrics if enable-metrics set, and listen-metrics-http not set @@ -229,7 +229,7 @@ func New(conf *Config) (*Server, error) { if err != nil { return nil, err } - fileCache, err := createAttachmentStore(conf, messageCache) + attachmentStore, err := createAttachmentStore(conf, messageCache) if err != nil { return nil, err } @@ -275,7 +275,7 @@ func New(conf *Config) (*Server, error) { db: pool, messageCache: messageCache, webPush: wp, - fileCache: fileCache, + attachment: attachmentStore, firebaseClient: firebaseClient, smtpSender: mailer, topics: topics, @@ -432,8 +432,8 @@ func (s *Server) Stop() { if s.smtpServer != nil { s.smtpServer.Close() } - if s.fileCache != nil { - s.fileCache.Close() + if s.attachment != nil { + s.attachment.Close() } s.closeDatabases() close(s.closeChan) @@ -609,7 +609,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.ensureWebEnabled(s.handleStatic)(w, r, v) } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { return s.ensureWebEnabled(s.handleDocs)(w, r, v) - } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.fileCache != nil { + } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.attachment != nil { return s.limitRequests(s.handleFile)(w, r, v) } else if r.Method == http.MethodOptions { return s.limitRequests(s.handleOptions)(w, r, v) // Should work even if the web app is not enabled, 
see #598 @@ -766,7 +766,7 @@ func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor) // Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it // can associate the download bandwidth with the uploader. func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error { - if s.fileCache == nil { + if s.attachment == nil { return errHTTPInternalError } matches := fileRegex.FindStringSubmatch(r.URL.Path) @@ -774,7 +774,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) return errHTTPInternalErrorInvalidPath } messageID := matches[1] - reader, size, err := s.fileCache.Read(messageID) + reader, size, err := s.attachment.Read(messageID) if err != nil { return errHTTPNotFound.Fields(log.Context{ "message_id": messageID, @@ -935,8 +935,8 @@ func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*model.Mess return nil, err } // Delete attachment files for deleted scheduled messages - if s.fileCache != nil && len(deletedIDs) > 0 { - if err := s.fileCache.Remove(deletedIDs...); err != nil { + if s.attachment != nil && len(deletedIDs) > 0 { + if err := s.attachment.Remove(deletedIDs...); err != nil { logvrm(v, r, m).Tag(tagPublish).Err(err).Warn("Error removing attachments for deleted scheduled messages") } } @@ -1042,8 +1042,8 @@ func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v * return err } // Delete attachment files for deleted scheduled messages - if s.fileCache != nil && len(deletedIDs) > 0 { - if err := s.fileCache.Remove(deletedIDs...); err != nil { + if s.attachment != nil && len(deletedIDs) > 0 { + if err := s.attachment.Remove(deletedIDs...); err != nil { logvrm(v, r, m).Tag(tagPublish).Err(err).Warn("Error removing attachments for deleted scheduled messages") } } @@ -1421,7 +1421,7 @@ func (s *Server) renderTemplate(name, tpl, source string) (string, error) { } func (s *Server) 
handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Message, body *util.PeekedReadCloser) error { - if s.fileCache == nil || s.config.BaseURL == "" { + if s.attachment == nil || s.config.BaseURL == "" { return errHTTPBadRequestAttachmentsDisallowed.With(m) } vinfo, err := v.Info() @@ -1458,7 +1458,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *model.Me util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit), util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining), } - m.Attachment.Size, err = s.fileCache.Write(m.ID, body, r.ContentLength, limiters...) + m.Attachment.Size, err = s.attachment.Write(m.ID, body, r.ContentLength, limiters...) if errors.Is(err, util.ErrLimitReached) { return errHTTPEntityTooLargeAttachment.With(m) } else if err != nil { diff --git a/server/server_manager.go b/server/server_manager.go index 5bf42924..89ff38c2 100644 --- a/server/server_manager.go +++ b/server/server_manager.go @@ -99,8 +99,8 @@ func (s *Server) execManager() { mset(metricUsers, usersCount) mset(metricSubscribers, subscribers) mset(metricTopics, topicsCount) - if s.fileCache != nil { - mset(metricAttachmentsTotalSize, s.fileCache.Size()) + if s.attachment != nil { + mset(metricAttachmentsTotalSize, s.attachment.Size()) } } @@ -140,7 +140,7 @@ func (s *Server) pruneTokens() { } func (s *Server) pruneAttachments() { - if s.fileCache == nil { + if s.attachment == nil { return } log. 
@@ -153,7 +153,7 @@ func (s *Server) pruneAttachments() { if log.Tag(tagManager).IsDebug() { log.Tag(tagManager).Debug("Deleting attachments %s", strings.Join(ids, ", ")) } - if err := s.fileCache.Remove(ids...); err != nil { + if err := s.attachment.Remove(ids...); err != nil { log.Tag(tagManager).Err(err).Warn("Error deleting attachments") } if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil { @@ -174,8 +174,8 @@ func (s *Server) pruneMessages() { if err != nil { log.Tag(tagManager).Err(err).Warn("Error retrieving expired messages") } else if len(expiredMessageIDs) > 0 { - if s.fileCache != nil { - if err := s.fileCache.Remove(expiredMessageIDs...); err != nil { + if s.attachment != nil { + if err := s.attachment.Remove(expiredMessageIDs...); err != nil { log.Tag(tagManager).Err(err).Warn("Error deleting attachments for expired messages") } } diff --git a/web/package-lock.json b/web/package-lock.json index e2ec6f9f..175ef11b 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -2738,9 +2738,9 @@ } }, "node_modules/@rollup/rollup-android-arm-eabi": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.59.0.tgz", - "integrity": "sha512-upnNBkA6ZH2VKGcBj9Fyl9IGNPULcjXRlg0LLeaioQWueH30p6IXtJEbKAgvyv+mJaMxSm1l6xwDXYjpEMiLMg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.60.0.tgz", + "integrity": "sha512-WOhNW9K8bR3kf4zLxbfg6Pxu2ybOUbB2AjMDHSQx86LIF4rH4Ft7vmMwNt0loO0eonglSNy4cpD3MKXXKQu0/A==", "cpu": [ "arm" ], @@ -2752,9 +2752,9 @@ ] }, "node_modules/@rollup/rollup-android-arm64": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.59.0.tgz", - "integrity": "sha512-hZ+Zxj3SySm4A/DylsDKZAeVg0mvi++0PYVceVyX7hemkw7OreKdCvW2oQ3T1FMZvCaQXqOTHb8qmBShoqk69Q==", + "version": "4.60.0", + "resolved": 
"https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.60.0.tgz", + "integrity": "sha512-u6JHLll5QKRvjciE78bQXDmqRqNs5M/3GVqZeMwvmjaNODJih/WIrJlFVEihvV0MiYFmd+ZyPr9wxOVbPAG2Iw==", "cpu": [ "arm64" ], @@ -2766,9 +2766,9 @@ ] }, "node_modules/@rollup/rollup-darwin-arm64": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.59.0.tgz", - "integrity": "sha512-W2Psnbh1J8ZJw0xKAd8zdNgF9HRLkdWwwdWqubSVk0pUuQkoHnv7rx4GiF9rT4t5DIZGAsConRE3AxCdJ4m8rg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.60.0.tgz", + "integrity": "sha512-qEF7CsKKzSRc20Ciu2Zw1wRrBz4g56F7r/vRwY430UPp/nt1x21Q/fpJ9N5l47WWvJlkNCPJz3QRVw008fi7yA==", "cpu": [ "arm64" ], @@ -2780,9 +2780,9 @@ ] }, "node_modules/@rollup/rollup-darwin-x64": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.59.0.tgz", - "integrity": "sha512-ZW2KkwlS4lwTv7ZVsYDiARfFCnSGhzYPdiOU4IM2fDbL+QGlyAbjgSFuqNRbSthybLbIJ915UtZBtmuLrQAT/w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.60.0.tgz", + "integrity": "sha512-WADYozJ4QCnXCH4wPB+3FuGmDPoFseVCUrANmA5LWwGmC6FL14BWC7pcq+FstOZv3baGX65tZ378uT6WG8ynTw==", "cpu": [ "x64" ], @@ -2794,9 +2794,9 @@ ] }, "node_modules/@rollup/rollup-freebsd-arm64": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.59.0.tgz", - "integrity": "sha512-EsKaJ5ytAu9jI3lonzn3BgG8iRBjV4LxZexygcQbpiU0wU0ATxhNVEpXKfUa0pS05gTcSDMKpn3Sx+QB9RlTTA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.60.0.tgz", + "integrity": "sha512-6b8wGHJlDrGeSE3aH5mGNHBjA0TTkxdoNHik5EkvPHCt351XnigA4pS7Wsj/Eo9Y8RBU6f35cjN9SYmCFBtzxw==", "cpu": [ "arm64" ], @@ -2808,9 +2808,9 @@ ] }, 
"node_modules/@rollup/rollup-freebsd-x64": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.59.0.tgz", - "integrity": "sha512-d3DuZi2KzTMjImrxoHIAODUZYoUUMsuUiY4SRRcJy6NJoZ6iIqWnJu9IScV9jXysyGMVuW+KNzZvBLOcpdl3Vg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.60.0.tgz", + "integrity": "sha512-h25Ga0t4jaylMB8M/JKAyrvvfxGRjnPQIR8lnCayyzEjEOx2EJIlIiMbhpWxDRKGKF8jbNH01NnN663dH638mA==", "cpu": [ "x64" ], @@ -2822,9 +2822,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-gnueabihf": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.59.0.tgz", - "integrity": "sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.60.0.tgz", + "integrity": "sha512-RzeBwv0B3qtVBWtcuABtSuCzToo2IEAIQrcyB/b2zMvBWVbjo8bZDjACUpnaafaxhTw2W+imQbP2BD1usasK4g==", "cpu": [ "arm" ], @@ -2836,9 +2836,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm-musleabihf": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.59.0.tgz", - "integrity": "sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.60.0.tgz", + "integrity": "sha512-Sf7zusNI2CIU1HLzuu9Tc5YGAHEZs5Lu7N1ssJG4Tkw6e0MEsN7NdjUDDfGNHy2IU+ENyWT+L2obgWiguWibWQ==", "cpu": [ "arm" ], @@ -2850,9 +2850,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-gnu": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.59.0.tgz", - "integrity": 
"sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.60.0.tgz", + "integrity": "sha512-DX2x7CMcrJzsE91q7/O02IJQ5/aLkVtYFryqCjduJhUfGKG6yJV8hxaw8pZa93lLEpPTP/ohdN4wFz7yp/ry9A==", "cpu": [ "arm64" ], @@ -2864,9 +2864,9 @@ ] }, "node_modules/@rollup/rollup-linux-arm64-musl": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.59.0.tgz", - "integrity": "sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.60.0.tgz", + "integrity": "sha512-09EL+yFVbJZlhcQfShpswwRZ0Rg+z/CsSELFCnPt3iK+iqwGsI4zht3secj5vLEs957QvFFXnzAT0FFPIxSrkQ==", "cpu": [ "arm64" ], @@ -2878,9 +2878,9 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-gnu": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.59.0.tgz", - "integrity": "sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-gnu/-/rollup-linux-loong64-gnu-4.60.0.tgz", + "integrity": "sha512-i9IcCMPr3EXm8EQg5jnja0Zyc1iFxJjZWlb4wr7U2Wx/GrddOuEafxRdMPRYVaXjgbhvqalp6np07hN1w9kAKw==", "cpu": [ "loong64" ], @@ -2892,9 +2892,9 @@ ] }, "node_modules/@rollup/rollup-linux-loong64-musl": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.59.0.tgz", - "integrity": "sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==", + "version": "4.60.0", + "resolved": 
"https://registry.npmjs.org/@rollup/rollup-linux-loong64-musl/-/rollup-linux-loong64-musl-4.60.0.tgz", + "integrity": "sha512-DGzdJK9kyJ+B78MCkWeGnpXJ91tK/iKA6HwHxF4TAlPIY7GXEvMe8hBFRgdrR9Ly4qebR/7gfUs9y2IoaVEyog==", "cpu": [ "loong64" ], @@ -2906,9 +2906,9 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-gnu": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.59.0.tgz", - "integrity": "sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-gnu/-/rollup-linux-ppc64-gnu-4.60.0.tgz", + "integrity": "sha512-RwpnLsqC8qbS8z1H1AxBA1H6qknR4YpPR9w2XX0vo2Sz10miu57PkNcnHVaZkbqyw/kUWfKMI73jhmfi9BRMUQ==", "cpu": [ "ppc64" ], @@ -2920,9 +2920,9 @@ ] }, "node_modules/@rollup/rollup-linux-ppc64-musl": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.59.0.tgz", - "integrity": "sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-ppc64-musl/-/rollup-linux-ppc64-musl-4.60.0.tgz", + "integrity": "sha512-Z8pPf54Ly3aqtdWC3G4rFigZgNvd+qJlOE52fmko3KST9SoGfAdSRCwyoyG05q1HrrAblLbk1/PSIV+80/pxLg==", "cpu": [ "ppc64" ], @@ -2934,9 +2934,9 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-gnu": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.59.0.tgz", - "integrity": "sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.60.0.tgz", + "integrity": 
"sha512-3a3qQustp3COCGvnP4SvrMHnPQ9d1vzCakQVRTliaz8cIp/wULGjiGpbcqrkv0WrHTEp8bQD/B3HBjzujVWLOA==", "cpu": [ "riscv64" ], @@ -2948,9 +2948,9 @@ ] }, "node_modules/@rollup/rollup-linux-riscv64-musl": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.59.0.tgz", - "integrity": "sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-musl/-/rollup-linux-riscv64-musl-4.60.0.tgz", + "integrity": "sha512-pjZDsVH/1VsghMJ2/kAaxt6dL0psT6ZexQVrijczOf+PeP2BUqTHYejk3l6TlPRydggINOeNRhvpLa0AYpCWSQ==", "cpu": [ "riscv64" ], @@ -2962,9 +2962,9 @@ ] }, "node_modules/@rollup/rollup-linux-s390x-gnu": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.59.0.tgz", - "integrity": "sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.60.0.tgz", + "integrity": "sha512-3ObQs0BhvPgiUVZrN7gqCSvmFuMWvWvsjG5ayJ3Lraqv+2KhOsp+pUbigqbeWqueGIsnn+09HBw27rJ+gYK4VQ==", "cpu": [ "s390x" ], @@ -2976,9 +2976,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-gnu": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.59.0.tgz", - "integrity": "sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.60.0.tgz", + "integrity": "sha512-EtylprDtQPdS5rXvAayrNDYoJhIz1/vzN2fEubo3yLE7tfAw+948dO0g4M0vkTVFhKojnF+n6C8bDNe+gDRdTg==", "cpu": [ "x64" ], @@ -2990,9 +2990,9 @@ ] }, "node_modules/@rollup/rollup-linux-x64-musl": { - "version": "4.59.0", 
- "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.59.0.tgz", - "integrity": "sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.60.0.tgz", + "integrity": "sha512-k09oiRCi/bHU9UVFqD17r3eJR9bn03TyKraCrlz5ULFJGdJGi7VOmm9jl44vOJvRJ6P7WuBi/s2A97LxxHGIdw==", "cpu": [ "x64" ], @@ -3004,9 +3004,9 @@ ] }, "node_modules/@rollup/rollup-openbsd-x64": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.59.0.tgz", - "integrity": "sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openbsd-x64/-/rollup-openbsd-x64-4.60.0.tgz", + "integrity": "sha512-1o/0/pIhozoSaDJoDcec+IVLbnRtQmHwPV730+AOD29lHEEo4F5BEUB24H0OBdhbBBDwIOSuf7vgg0Ywxdfiiw==", "cpu": [ "x64" ], @@ -3018,9 +3018,9 @@ ] }, "node_modules/@rollup/rollup-openharmony-arm64": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.59.0.tgz", - "integrity": "sha512-tt9KBJqaqp5i5HUZzoafHZX8b5Q2Fe7UjYERADll83O4fGqJ49O1FsL6LpdzVFQcpwvnyd0i+K/VSwu/o/nWlA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-openharmony-arm64/-/rollup-openharmony-arm64-4.60.0.tgz", + "integrity": "sha512-pESDkos/PDzYwtyzB5p/UoNU/8fJo68vcXM9ZW2V0kjYayj1KaaUfi1NmTUTUpMn4UhU4gTuK8gIaFO4UGuMbA==", "cpu": [ "arm64" ], @@ -3032,9 +3032,9 @@ ] }, "node_modules/@rollup/rollup-win32-arm64-msvc": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.59.0.tgz", - "integrity": "sha512-V5B6mG7OrGTwnxaNUzZTDTjDS7F75PO1ae6MJYdiMu60sq0CqN5CVeVsbhPxalupvTX8gXVSU9gq+Rx1/hvu6A==", + "version": 
"4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.60.0.tgz", + "integrity": "sha512-hj1wFStD7B1YBeYmvY+lWXZ7ey73YGPcViMShYikqKT1GtstIKQAtfUI6yrzPjAy/O7pO0VLXGmUVWXQMaYgTQ==", "cpu": [ "arm64" ], @@ -3046,9 +3046,9 @@ ] }, "node_modules/@rollup/rollup-win32-ia32-msvc": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.59.0.tgz", - "integrity": "sha512-UKFMHPuM9R0iBegwzKF4y0C4J9u8C6MEJgFuXTBerMk7EJ92GFVFYBfOZaSGLu6COf7FxpQNqhNS4c4icUPqxA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.60.0.tgz", + "integrity": "sha512-SyaIPFoxmUPlNDq5EHkTbiKzmSEmq/gOYFI/3HHJ8iS/v1mbugVa7dXUzcJGQfoytp9DJFLhHH4U3/eTy2Bq4w==", "cpu": [ "ia32" ], @@ -3060,9 +3060,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-gnu": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.59.0.tgz", - "integrity": "sha512-laBkYlSS1n2L8fSo1thDNGrCTQMmxjYY5G0WFWjFFYZkKPjsMBsgJfGf4TLxXrF6RyhI60L8TMOjBMvXiTcxeA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-gnu/-/rollup-win32-x64-gnu-4.60.0.tgz", + "integrity": "sha512-RdcryEfzZr+lAr5kRm2ucN9aVlCCa2QNq4hXelZxb8GG0NJSazq44Z3PCCc8wISRuCVnGs0lQJVX5Vp6fKA+IA==", "cpu": [ "x64" ], @@ -3074,9 +3074,9 @@ ] }, "node_modules/@rollup/rollup-win32-x64-msvc": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.59.0.tgz", - "integrity": "sha512-2HRCml6OztYXyJXAvdDXPKcawukWY2GpR5/nxKp4iBgiO3wcoEGkAaqctIbZcNB6KlUQBIqt8VYkNSj2397EfA==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.60.0.tgz", + "integrity": "sha512-PrsWNQ8BuE00O3Xsx3ALh2Df8fAj9+cvvX9AIA6o4KpATR98c9mud4XtDWVvsEuyia5U4tVSTKygawyJkjm60w==", 
"cpu": [ "x64" ], @@ -3642,9 +3642,9 @@ "license": "MIT" }, "node_modules/baseline-browser-mapping": { - "version": "2.10.9", - "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.9.tgz", - "integrity": "sha512-OZd0e2mU11ClX8+IdXe3r0dbqMEznRiT4TfbhYIbcRPZkqJ7Qwer8ij3GZAmLsRKa+II9V1v5czCkvmHH3XZBg==", + "version": "2.10.10", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.10.tgz", + "integrity": "sha512-sUoJ3IMxx4AyRqO4MLeHlnGDkyXRoUG0/AI9fjK+vS72ekpV0yWVY7O0BVjmBcRtkNcsAO2QDZ4tdKKGoI6YaQ==", "dev": true, "license": "Apache-2.0", "bin": { @@ -3940,9 +3940,9 @@ } }, "node_modules/cosmiconfig/node_modules/yaml": { - "version": "1.10.2", - "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.2.tgz", - "integrity": "sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==", + "version": "1.10.3", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.3.tgz", + "integrity": "sha512-vIYeF1u3CjlhAFekPPAk2h/Kv4T3mAkMox5OymRiJQB0spDP10LHvt+K7G9Ny6NuuMAb25/6n1qyUjAcGNf/AA==", "license": "ISC", "engines": { "node": ">= 6" @@ -7580,9 +7580,9 @@ } }, "node_modules/rollup": { - "version": "4.59.0", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.59.0.tgz", - "integrity": "sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==", + "version": "4.60.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.60.0.tgz", + "integrity": "sha512-yqjxruMGBQJ2gG4HtjZtAfXArHomazDHoFwFFmZZl0r7Pdo7qCIXKqKHZc8yeoMgzJJ+pO6pEEHa+V7uzWlrAQ==", "dev": true, "license": "MIT", "dependencies": { @@ -7596,31 +7596,31 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.59.0", - "@rollup/rollup-android-arm64": "4.59.0", - "@rollup/rollup-darwin-arm64": "4.59.0", - "@rollup/rollup-darwin-x64": "4.59.0", - "@rollup/rollup-freebsd-arm64": "4.59.0", - 
"@rollup/rollup-freebsd-x64": "4.59.0", - "@rollup/rollup-linux-arm-gnueabihf": "4.59.0", - "@rollup/rollup-linux-arm-musleabihf": "4.59.0", - "@rollup/rollup-linux-arm64-gnu": "4.59.0", - "@rollup/rollup-linux-arm64-musl": "4.59.0", - "@rollup/rollup-linux-loong64-gnu": "4.59.0", - "@rollup/rollup-linux-loong64-musl": "4.59.0", - "@rollup/rollup-linux-ppc64-gnu": "4.59.0", - "@rollup/rollup-linux-ppc64-musl": "4.59.0", - "@rollup/rollup-linux-riscv64-gnu": "4.59.0", - "@rollup/rollup-linux-riscv64-musl": "4.59.0", - "@rollup/rollup-linux-s390x-gnu": "4.59.0", - "@rollup/rollup-linux-x64-gnu": "4.59.0", - "@rollup/rollup-linux-x64-musl": "4.59.0", - "@rollup/rollup-openbsd-x64": "4.59.0", - "@rollup/rollup-openharmony-arm64": "4.59.0", - "@rollup/rollup-win32-arm64-msvc": "4.59.0", - "@rollup/rollup-win32-ia32-msvc": "4.59.0", - "@rollup/rollup-win32-x64-gnu": "4.59.0", - "@rollup/rollup-win32-x64-msvc": "4.59.0", + "@rollup/rollup-android-arm-eabi": "4.60.0", + "@rollup/rollup-android-arm64": "4.60.0", + "@rollup/rollup-darwin-arm64": "4.60.0", + "@rollup/rollup-darwin-x64": "4.60.0", + "@rollup/rollup-freebsd-arm64": "4.60.0", + "@rollup/rollup-freebsd-x64": "4.60.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.60.0", + "@rollup/rollup-linux-arm-musleabihf": "4.60.0", + "@rollup/rollup-linux-arm64-gnu": "4.60.0", + "@rollup/rollup-linux-arm64-musl": "4.60.0", + "@rollup/rollup-linux-loong64-gnu": "4.60.0", + "@rollup/rollup-linux-loong64-musl": "4.60.0", + "@rollup/rollup-linux-ppc64-gnu": "4.60.0", + "@rollup/rollup-linux-ppc64-musl": "4.60.0", + "@rollup/rollup-linux-riscv64-gnu": "4.60.0", + "@rollup/rollup-linux-riscv64-musl": "4.60.0", + "@rollup/rollup-linux-s390x-gnu": "4.60.0", + "@rollup/rollup-linux-x64-gnu": "4.60.0", + "@rollup/rollup-linux-x64-musl": "4.60.0", + "@rollup/rollup-openbsd-x64": "4.60.0", + "@rollup/rollup-openharmony-arm64": "4.60.0", + "@rollup/rollup-win32-arm64-msvc": "4.60.0", + "@rollup/rollup-win32-ia32-msvc": "4.60.0", + 
"@rollup/rollup-win32-x64-gnu": "4.60.0", + "@rollup/rollup-win32-x64-msvc": "4.60.0", "fsevents": "~2.3.2" } }, @@ -9515,9 +9515,9 @@ "license": "ISC" }, "node_modules/yaml": { - "version": "2.8.2", - "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.2.tgz", - "integrity": "sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==", + "version": "2.8.3", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.3.tgz", + "integrity": "sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg==", "dev": true, "license": "ISC", "optional": true, From 59ec76e8b2ac6f593049ab18dd6f9d0f0da0ec58 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 15:10:28 -0400 Subject: [PATCH 24/32] Fix brittle tests, move delete batching into client package, run s3 tests against real bucket --- attachment/backend_s3.go | 14 +- attachment/store_s3_test.go | 44 ++- attachment/store_test.go | 2 +- s3/client.go | 23 +- s3/client_test.go | 685 +++++------------------------------- s3/util.go | 3 + 6 files changed, 140 insertions(+), 631 deletions(-) diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index 44f946f6..9a2d4bef 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -8,8 +8,6 @@ import ( "heckel.io/ntfy/v2/s3" ) -const deleteBatchSize = 1000 - type s3Backend struct { client *s3.Client } @@ -45,17 +43,7 @@ func (b *s3Backend) List() ([]object, error) { } func (b *s3Backend) Delete(ids ...string) error { - // S3 DeleteObjects supports up to 1000 keys per call - for i := 0; i < len(ids); i += deleteBatchSize { - end := i + deleteBatchSize - if end > len(ids) { - end = len(ids) - } - if err := b.client.DeleteObjects(context.Background(), ids[i:end]); err != nil { - return err - } - } - return nil + return b.client.DeleteObjects(context.Background(), ids) } func (b *s3Backend) DeleteIncomplete(cutoff time.Time) error { diff --git a/attachment/store_s3_test.go 
b/attachment/store_s3_test.go index a41c6f8b..6615f4e9 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -14,20 +14,20 @@ import ( ) func TestS3Store_WriteWithPrefix(t *testing.T) { - s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL") + s3URL := os.Getenv("NTFY_TEST_S3_URL") if s3URL == "" { - t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set") + t.Skip("NTFY_TEST_S3_URL not set") } cfg, err := s3.ParseURL(s3URL) require.Nil(t, err) cfg.Prefix = "test-prefix" client := s3.New(cfg) - deleteAllObjects(client) + deleteAllObjects(t, client) backend := newS3Backend(client) cache, err := newStore(backend, 10*1024, nil) require.Nil(t, err) t.Cleanup(func() { - deleteAllObjects(client) + deleteAllObjects(t, client) cache.Close() }) @@ -47,34 +47,46 @@ func TestS3Store_WriteWithPrefix(t *testing.T) { func newTestRealS3Store(t *testing.T, totalSizeLimit int64) (*Store, *modTimeOverrideBackend) { t.Helper() - s3URL := os.Getenv("NTFY_TEST_ATTACHMENT_S3_URL") + s3URL := os.Getenv("NTFY_TEST_S3_URL") if s3URL == "" { - t.Skip("NTFY_TEST_ATTACHMENT_S3_URL not set") + t.Skip("NTFY_TEST_S3_URL not set") } cfg, err := s3.ParseURL(s3URL) require.Nil(t, err) + if cfg.Prefix != "" { + cfg.Prefix = cfg.Prefix + "/testpkg-attachment" + } else { + cfg.Prefix = "testpkg-attachment" + } client := s3.New(cfg) inner := newS3Backend(client) wrapper := &modTimeOverrideBackend{backend: inner, modTimes: make(map[string]time.Time)} - deleteAllObjects(client) + deleteAllObjects(t, client) store, err := newStore(wrapper, totalSizeLimit, nil) require.Nil(t, err) t.Cleanup(func() { - deleteAllObjects(client) + deleteAllObjects(t, client) store.Close() }) return store, wrapper } -func deleteAllObjects(client *s3.Client) { - objects, _ := client.ListObjectsV2(context.Background()) - keys := make([]string, 0, len(objects)) - for _, obj := range objects { - keys = append(keys, obj.Key) - } - if len(keys) > 0 { - client.DeleteObjects(context.Background(), keys) //nolint:errcheck +func 
deleteAllObjects(t *testing.T, client *s3.Client) { + t.Helper() + for i := 0; i < 20; i++ { + objects, err := client.ListObjectsV2(context.Background()) + require.Nil(t, err) + if len(objects) == 0 { + return + } + keys := make([]string, len(objects)) + for j, obj := range objects { + keys[j] = obj.Key + } + require.Nil(t, client.DeleteObjects(context.Background(), keys)) + time.Sleep(200 * time.Millisecond) } + t.Fatal("timed out waiting for bucket to be empty") } // modTimeOverrideBackend wraps a backend and allows overriding LastModified times returned by List(). diff --git a/attachment/store_test.go b/attachment/store_test.go index 7ac7cddb..11d0b244 100644 --- a/attachment/store_test.go +++ b/attachment/store_test.go @@ -332,7 +332,7 @@ func TestStore_Sync_SkipsRecentFiles(t *testing.T) { // callback that makes a specific object's timestamp old enough for orphan cleanup (> 1 hour). // For the file backend, this uses os.Chtimes; for the S3 backend, it overrides the object's // LastModified time via a modTimeOverrideBackend wrapper. Objects start with recent timestamps -// by default. The S3 subtest is skipped if NTFY_TEST_ATTACHMENT_S3_URL is not set. +// by default. The S3 subtest is skipped if NTFY_TEST_S3_URL is not set. 
func forEachBackend(t *testing.T, totalSizeLimit int64, f func(t *testing.T, s *Store, makeOld func(string))) { t.Run("file", func(t *testing.T) { dir, s := newTestFileStore(t, totalSizeLimit) diff --git a/s3/client.go b/s3/client.go index d9ec1ab8..e06ff5c9 100644 --- a/s3/client.go +++ b/s3/client.go @@ -11,7 +11,6 @@ import ( "io" "net/http" "net/url" - "strconv" "strings" "time" @@ -125,7 +124,7 @@ func (c *Client) ListObjectsV2(ctx context.Context) ([]*Object, error) { var all []*Object var token string for page := 0; page < maxPages; page++ { - result, err := c.listObjectsV2(ctx, token, 0) + result, err := c.listObjectsV2(ctx, token) if err != nil { return nil, err } @@ -149,8 +148,7 @@ func (c *Client) ListObjectsV2(ctx context.Context) ([]*Object, error) { } // listObjectsV2 performs a single ListObjectsV2 request using the client's configured prefix. -// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000). -func (c *Client) listObjectsV2(ctx context.Context, continuationToken string, maxKeys int) (*listObjectsV2Result, error) { +func (c *Client) listObjectsV2(ctx context.Context, continuationToken string) (*listObjectsV2Result, error) { log.Tag(tagS3Client).Debug("Listing remote objects with continuation token '%s'", continuationToken) query := url.Values{"list-type": {"2"}} if prefix := c.config.ListPrefix(); prefix != "" { @@ -159,9 +157,6 @@ func (c *Client) listObjectsV2(ctx context.Context, continuationToken string, ma if continuationToken != "" { query.Set("continuation-token", continuationToken) } - if maxKeys > 0 { - query.Set("max-keys", strconv.Itoa(maxKeys)) - } respBody, err := c.do(ctx, "ListObjects", http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, nil) if err != nil { return nil, err @@ -182,6 +177,20 @@ func (c *Client) listObjectsV2(ctx context.Context, continuationToken string, ma // // See https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html func (c *Client) 
DeleteObjects(ctx context.Context, keys []string) error { + // S3 DeleteObjects supports up to 1000 keys per call + for i := 0; i < len(keys); i += maxDeleteBatchSize { + end := i + maxDeleteBatchSize + if end > len(keys) { + end = len(keys) + } + if err := c.deleteObjects(ctx, keys[i:end]); err != nil { + return err + } + } + return nil +} + +func (c *Client) deleteObjects(ctx context.Context, keys []string) error { log.Tag(tagS3Client).Debug("Deleting %d object(s)", len(keys)) req := &deleteObjectsRequest{ Quiet: true, diff --git a/s3/client_test.go b/s3/client_test.go index 652db3e7..f4b85089 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -3,13 +3,9 @@ package s3 import ( "bytes" "context" - "encoding/xml" "fmt" "io" - "net/http" - "net/http/httptest" "os" - "sort" "strings" "sync" "testing" @@ -18,271 +14,6 @@ import ( "github.com/stretchr/testify/require" ) -// --- Mock S3 server --- -// -// A minimal S3-compatible HTTP server that supports PutObject, GetObject, DeleteObjects, and -// ListObjectsV2. Uses path-style addressing: /{bucket}/{key}. Objects are stored in memory. 
- -type mockS3Server struct { - objects map[string][]byte // full key (bucket/key) -> body - uploads map[string]map[int][]byte // uploadID -> partNumber -> data - nextID int // counter for generating upload IDs - mu sync.RWMutex -} - -func newMockS3Server() (*httptest.Server, *mockS3Server) { - m := &mockS3Server{ - objects: make(map[string][]byte), - uploads: make(map[string]map[int][]byte), - } - return httptest.NewTLSServer(m), m -} - -func (m *mockS3Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Path is /{bucket}[/{key...}] - path := strings.TrimPrefix(r.URL.Path, "/") - q := r.URL.Query() - - switch { - case r.Method == http.MethodPut && q.Has("partNumber"): - m.handleUploadPart(w, r, path) - case r.Method == http.MethodPut: - m.handlePut(w, r, path) - case r.Method == http.MethodPost && q.Has("uploads"): - m.handleInitiateMultipart(w, r, path) - case r.Method == http.MethodPost && q.Has("uploadId"): - m.handleCompleteMultipart(w, r, path) - case r.Method == http.MethodDelete && q.Has("uploadId"): - m.handleAbortMultipart(w, r, path) - case r.Method == http.MethodGet && q.Get("list-type") == "2": - m.handleList(w, r, path) - case r.Method == http.MethodGet: - m.handleGet(w, r, path) - case r.Method == http.MethodPost && q.Has("delete"): - m.handleDelete(w, r, path) - default: - http.Error(w, "not implemented", http.StatusNotImplemented) - } -} - -func (m *mockS3Server) handlePut(w http.ResponseWriter, r *http.Request, path string) { - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - m.mu.Lock() - m.objects[path] = body - m.mu.Unlock() - w.WriteHeader(http.StatusOK) -} - -func (m *mockS3Server) handleInitiateMultipart(w http.ResponseWriter, r *http.Request, path string) { - m.mu.Lock() - m.nextID++ - uploadID := fmt.Sprintf("upload-%d", m.nextID) - m.uploads[uploadID] = make(map[int][]byte) - m.mu.Unlock() - - w.Header().Set("Content-Type", "application/xml") - 
w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `%s`, uploadID) -} - -func (m *mockS3Server) handleUploadPart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - var partNumber int - fmt.Sscanf(r.URL.Query().Get("partNumber"), "%d", &partNumber) - - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - m.mu.Lock() - parts, ok := m.uploads[uploadID] - if !ok { - m.mu.Unlock() - http.Error(w, "NoSuchUpload", http.StatusNotFound) - return - } - parts[partNumber] = body - m.mu.Unlock() - - etag := fmt.Sprintf(`"etag-part-%d"`, partNumber) - w.Header().Set("ETag", etag) - w.WriteHeader(http.StatusOK) -} - -func (m *mockS3Server) handleCompleteMultipart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - - m.mu.Lock() - parts, ok := m.uploads[uploadID] - if !ok { - m.mu.Unlock() - http.Error(w, "NoSuchUpload", http.StatusNotFound) - return - } - - // Assemble parts in order - var assembled []byte - for i := 1; i <= len(parts); i++ { - assembled = append(assembled, parts[i]...) 
- } - m.objects[path] = assembled - delete(m.uploads, uploadID) - m.mu.Unlock() - - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `%s`, path) -} - -func (m *mockS3Server) handleAbortMultipart(w http.ResponseWriter, r *http.Request, path string) { - uploadID := r.URL.Query().Get("uploadId") - m.mu.Lock() - delete(m.uploads, uploadID) - m.mu.Unlock() - w.WriteHeader(http.StatusNoContent) -} - -func (m *mockS3Server) handleGet(w http.ResponseWriter, r *http.Request, path string) { - m.mu.RLock() - body, ok := m.objects[path] - m.mu.RUnlock() - if !ok { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte(`NoSuchKeyThe specified key does not exist.`)) - return - } - w.Header().Set("Content-Length", fmt.Sprintf("%d", len(body))) - w.WriteHeader(http.StatusOK) - w.Write(body) -} - -type listObjectsResponse struct { - XMLName xml.Name `xml:"ListBucketResult"` - Contents []listObject `xml:"Contents"` - // Pagination support - IsTruncated bool `xml:"IsTruncated"` - NextContinuationToken string `xml:"NextContinuationToken"` -} - -func (m *mockS3Server) handleDelete(w http.ResponseWriter, r *http.Request, bucketPath string) { - // bucketPath is just the bucket name - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - var req struct { - Objects []struct { - Key string `xml:"Key"` - } `xml:"Object"` - } - if err := xml.Unmarshal(body, &req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - m.mu.Lock() - for _, obj := range req.Objects { - delete(m.objects, bucketPath+"/"+obj.Key) - } - m.mu.Unlock() - w.WriteHeader(http.StatusOK) - w.Write([]byte(``)) -} - -func (m *mockS3Server) handleList(w http.ResponseWriter, r *http.Request, bucketPath string) { - prefix := r.URL.Query().Get("prefix") - contToken := r.URL.Query().Get("continuation-token") - - m.mu.RLock() - var allKeys []string - for key := range m.objects { - objKey 
:= strings.TrimPrefix(key, bucketPath+"/") - if objKey == key { - continue // different bucket - } - if prefix == "" || strings.HasPrefix(objKey, prefix) { - allKeys = append(allKeys, objKey) - } - } - m.mu.RUnlock() - sort.Strings(allKeys) - - // Simple continuation token: it's the key to start after - startIdx := 0 - if contToken != "" { - for i, k := range allKeys { - if k == contToken { - startIdx = i + 1 - break - } - } - } - - maxKeys := 1000 - if mk := r.URL.Query().Get("max-keys"); mk != "" { - fmt.Sscanf(mk, "%d", &maxKeys) - } - - endIdx := startIdx + maxKeys - truncated := false - nextToken := "" - if endIdx < len(allKeys) { - truncated = true - nextToken = allKeys[endIdx-1] - allKeys = allKeys[startIdx:endIdx] - } else { - allKeys = allKeys[startIdx:] - } - - m.mu.RLock() - var contents []listObject - for _, objKey := range allKeys { - body := m.objects[bucketPath+"/"+objKey] - contents = append(contents, listObject{Key: objKey, Size: int64(len(body)), LastModified: time.Now().Format(time.RFC3339)}) - } - m.mu.RUnlock() - - resp := listObjectsResponse{ - Contents: contents, - IsTruncated: truncated, - NextContinuationToken: nextToken, - } - w.Header().Set("Content-Type", "application/xml") - w.WriteHeader(http.StatusOK) - xml.NewEncoder(w).Encode(resp) -} - -func (m *mockS3Server) objectCount() int { - m.mu.RLock() - defer m.mu.RUnlock() - return len(m.objects) -} - -// --- Helper to create a test client pointing at mock server --- - -func newTestClient(server *httptest.Server, bucket, prefix string) *Client { - // httptest.NewTLSServer URL is like "https://127.0.0.1:PORT" - host := strings.TrimPrefix(server.URL, "https://") - return New(&Config{ - AccessKey: "AKID", - SecretKey: "SECRET", - Region: "us-east-1", - Endpoint: host, - Bucket: bucket, - Prefix: prefix, - PathStyle: true, - HTTPClient: server.Client(), - }) -} - -// --- URL parsing tests --- - func TestParseURL_Success(t *testing.T) { cfg, err := 
ParseURL("s3://AKID:SECRET@my-bucket/attachments?region=us-east-1") require.Nil(t, err) @@ -409,13 +140,10 @@ func TestConfig_ListPrefix(t *testing.T) { require.Equal(t, "", c2.ListPrefix()) } -// --- Integration tests using mock S3 server --- +// --- Integration tests using real S3 --- func TestClient_PutGetObject(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "") - + client := newTestClient(t) ctx := context.Background() // Put @@ -432,152 +160,85 @@ func TestClient_PutGetObject(t *testing.T) { require.Equal(t, "hello world", string(data)) } -func TestClient_PutGetObject_WithPrefix(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "pfx") - - ctx := context.Background() - - err := client.PutObject(ctx, "test-key", strings.NewReader("hello"), 0) - require.Nil(t, err) - - reader, _, err := client.GetObject(ctx, "test-key") - require.Nil(t, err) - data, _ := io.ReadAll(reader) - reader.Close() - require.Equal(t, "hello", string(data)) -} - func TestClient_GetObject_NotFound(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "") + client := newTestClient(t) _, _, err := client.GetObject(context.Background(), "nonexistent") require.Error(t, err) - var errResp *errorResponse - require.ErrorAs(t, err, &errResp) - require.Equal(t, 404, errResp.StatusCode) - require.Equal(t, "NoSuchKey", errResp.Code) } func TestClient_DeleteObjects(t *testing.T) { - server, mock := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "") - + client := newTestClient(t) ctx := context.Background() // Put several objects for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%d", i), bytes.NewReader([]byte("data")), 0) + err := client.PutObject(ctx, fmt.Sprintf("del-%d", i), bytes.NewReader([]byte("data")), 0) require.Nil(t, err) } - 
require.Equal(t, 5, mock.objectCount()) + waitForCount(t, client, 5) // Delete some - err := client.DeleteObjects(ctx, []string{"key-1", "key-3"}) + err := client.DeleteObjects(ctx, []string{"del-1", "del-3"}) require.Nil(t, err) - require.Equal(t, 3, mock.objectCount()) + waitForCount(t, client, 3) // Verify deleted ones are gone - _, _, err = client.GetObject(ctx, "key-1") + _, _, err = client.GetObject(ctx, "del-1") require.Error(t, err) - _, _, err = client.GetObject(ctx, "key-3") + _, _, err = client.GetObject(ctx, "del-3") require.Error(t, err) // Verify remaining ones are still there - reader, _, err := client.GetObject(ctx, "key-0") - require.Nil(t, err) - reader.Close() + for _, key := range []string{"del-0", "del-2", "del-4"} { + reader, _, err := client.GetObject(ctx, key) + require.Nil(t, err) + reader.Close() + } } func TestClient_ListObjects(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - + client := newTestClient(t) ctx := context.Background() - // Client with prefix "pfx": list should only return objects under pfx/ - client := newTestClient(server, "my-bucket", "pfx") for i := 0; i < 3; i++ { - err := client.PutObject(ctx, fmt.Sprintf("%d", i), bytes.NewReader([]byte("x")), 0) + err := client.PutObject(ctx, fmt.Sprintf("list-%d", i), bytes.NewReader([]byte("x")), 0) require.Nil(t, err) } - - // Also put an object outside the prefix using a no-prefix client - clientNoPrefix := newTestClient(server, "my-bucket", "") - err := clientNoPrefix.PutObject(ctx, "other", bytes.NewReader([]byte("y")), 0) - require.Nil(t, err) - - // List with prefix client: should only see 3 - result, err := client.listObjectsV2(ctx, "", 0) - require.Nil(t, err) - require.Len(t, result.Contents, 3) - require.False(t, result.IsTruncated) - - // List with no-prefix client: should see all 4 - result, err = clientNoPrefix.listObjectsV2(ctx, "", 0) - require.Nil(t, err) - require.Len(t, result.Contents, 4) + waitForCount(t, client, 3) } func 
TestClient_ListObjects_Pagination(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "") - + client := newTestClient(t) ctx := context.Background() - // Put 5 objects - for i := 0; i < 5; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 0) + // Create 1010 objects in parallel (5 goroutines) + const total = 1010 + const workers = 5 + var wg sync.WaitGroup + errs := make(chan error, total) + for w := 0; w < workers; w++ { + wg.Add(1) + go func(start int) { + defer wg.Done() + for i := start; i < total; i += workers { + if err := client.PutObject(ctx, fmt.Sprintf("pg-%04d", i), bytes.NewReader([]byte("x")), 0); err != nil { + errs <- err + return + } + } + }(w) + } + wg.Wait() + close(errs) + for err := range errs { require.Nil(t, err) } - - // List with max-keys=2 - result, err := client.listObjectsV2(ctx, "", 2) - require.Nil(t, err) - require.Len(t, result.Contents, 2) - require.True(t, result.IsTruncated) - require.NotEmpty(t, result.NextContinuationToken) - - // Get next page - result2, err := client.listObjectsV2(ctx, result.NextContinuationToken, 2) - require.Nil(t, err) - require.Len(t, result2.Contents, 2) - require.True(t, result2.IsTruncated) - - // Get last page - result3, err := client.listObjectsV2(ctx, result2.NextContinuationToken, 2) - require.Nil(t, err) - require.Len(t, result3.Contents, 1) - require.False(t, result3.IsTruncated) -} - -func TestClient_ListAllObjects(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "pfx") - - ctx := context.Background() - - for i := 0; i < 10; i++ { - err := client.PutObject(ctx, fmt.Sprintf("key-%02d", i), bytes.NewReader([]byte("x")), 0) - require.Nil(t, err) - } - - objects, err := client.ListObjectsV2(ctx) - require.Nil(t, err) - require.Len(t, objects, 10) + waitForCount(t, client, total) } func TestClient_PutObject_LargeBody(t 
*testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "") - + client := newTestClient(t) ctx := context.Background() // 1 MB object @@ -598,10 +259,7 @@ func TestClient_PutObject_LargeBody(t *testing.T) { } func TestClient_PutObject_ChunkedUpload(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "") - + client := newTestClient(t) ctx := context.Background() // 12 MB object, exceeds 5 MB partSize, triggers multipart upload path @@ -622,10 +280,7 @@ func TestClient_PutObject_ChunkedUpload(t *testing.T) { } func TestClient_PutObject_ExactPartSize(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "") - + client := newTestClient(t) ctx := context.Background() // Exactly 5 MB (partSize), should use the simple put path (ReadFull succeeds fully) @@ -646,10 +301,7 @@ func TestClient_PutObject_ExactPartSize(t *testing.T) { } func TestClient_PutObject_StreamingExactLength(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "pfx") - + client := newTestClient(t) ctx := context.Background() // untrustedLength matches body exactly — streams directly via putObject @@ -666,10 +318,7 @@ func TestClient_PutObject_StreamingExactLength(t *testing.T) { } func TestClient_PutObject_StreamingBodyLongerThanClaimed(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "pfx") - + client := newTestClient(t) ctx := context.Background() // Body has 11 bytes, but we claim 5 — only first 5 bytes should be stored @@ -686,16 +335,12 @@ func TestClient_PutObject_StreamingBodyLongerThanClaimed(t *testing.T) { } func TestClient_PutObject_StreamingBodyShorterThanClaimed(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, 
"my-bucket", "pfx") - + client := newTestClient(t) ctx := context.Background() // Body has 5 bytes, but we claim 100 — should fail err := client.PutObject(ctx, "stream-short", strings.NewReader("hello"), 100) require.Error(t, err) - require.Contains(t, err.Error(), "ContentLength") // Object should not exist _, _, err = client.GetObject(ctx, "stream-short") @@ -703,10 +348,7 @@ func TestClient_PutObject_StreamingBodyShorterThanClaimed(t *testing.T) { } func TestClient_PutObject_NestedKey(t *testing.T) { - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "") - + client := newTestClient(t) ctx := context.Background() err := client.PutObject(ctx, "deep/nested/prefix/file.txt", strings.NewReader("nested"), 0) @@ -719,199 +361,54 @@ func TestClient_PutObject_NestedKey(t *testing.T) { require.Equal(t, "nested", string(data)) } -// --- Scale test: 20k objects (ntfy-adjacent) --- - -func TestClient_ListAllObjects_20k(t *testing.T) { - if testing.Short() { - t.Skip("skipping 20k object test in short mode") +func newTestClient(t *testing.T) *Client { + t.Helper() + s3URL := os.Getenv("NTFY_TEST_S3_URL") + if s3URL == "" { + t.Skip("NTFY_TEST_S3_URL not set") } - - server, _ := newMockS3Server() - defer server.Close() - client := newTestClient(server, "my-bucket", "attachments") - - ctx := context.Background() - const numObjects = 20000 - const batchSize = 500 - - // Insert 20k objects in batches to keep it fast - for batch := 0; batch < numObjects/batchSize; batch++ { - for i := 0; i < batchSize; i++ { - idx := batch*batchSize + i - key := fmt.Sprintf("%08d", idx) - err := client.PutObject(ctx, key, bytes.NewReader([]byte("x")), 0) - require.Nil(t, err) - } - } - - // List all 20k objects with pagination - objects, err := client.ListObjectsV2(ctx) + cfg, err := ParseURL(s3URL) require.Nil(t, err) - require.Len(t, objects, numObjects) - - // Verify total size - var totalSize int64 - for _, obj := range objects { - totalSize 
+= obj.Size + // Use per-test prefix to isolate objects between tests + if cfg.Prefix != "" { + cfg.Prefix = cfg.Prefix + "/testpkg-s3/" + t.Name() + } else { + cfg.Prefix = "testpkg-s3/" + t.Name() } - require.Equal(t, int64(numObjects), totalSize) - - // Delete 1000 objects (simulating attachment expiry cleanup) - keys := make([]string, 1000) - for i := range keys { - keys[i] = fmt.Sprintf("%08d", i) - } - err = client.DeleteObjects(ctx, keys) - require.Nil(t, err) - - // List again: should have 19000 - objects, err = client.ListObjectsV2(ctx) - require.Nil(t, err) - require.Len(t, objects, numObjects-1000) + client := New(cfg) + deleteAllObjects(t, client) + t.Cleanup(func() { deleteAllObjects(t, client) }) + return client } -// --- Real S3 integration test --- -// -// Set the following environment variables to run this test against a real S3 bucket: -// -// S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET -// -// Optional: -// -// S3_ENDPOINT: host[:port] for S3-compatible providers (e.g. 
"nyc3.digitaloceanspaces.com") -// S3_PATH_STYLE: set to "true" for path-style addressing -// S3_PREFIX: key prefix to use (default: "ntfy-s3-test") -func TestClient_RealBucket(t *testing.T) { - accessKey := os.Getenv("S3_ACCESS_KEY") - secretKey := os.Getenv("S3_SECRET_KEY") - region := os.Getenv("S3_REGION") - bucket := os.Getenv("S3_BUCKET") - - if accessKey == "" || secretKey == "" || region == "" || bucket == "" { - t.Skip("skipping real S3 test: set S3_ACCESS_KEY, S3_SECRET_KEY, S3_REGION, S3_BUCKET") +func deleteAllObjects(t *testing.T, client *Client) { + t.Helper() + for i := 0; i < 20; i++ { + objects, err := client.ListObjectsV2(context.Background()) + require.Nil(t, err) + if len(objects) == 0 { + return + } + keys := make([]string, len(objects)) + for j, obj := range objects { + keys[j] = obj.Key + } + require.Nil(t, client.DeleteObjects(context.Background(), keys)) + time.Sleep(200 * time.Millisecond) } - - endpoint := os.Getenv("S3_ENDPOINT") - if endpoint == "" { - endpoint = fmt.Sprintf("s3.%s.amazonaws.com", region) - } - pathStyle := os.Getenv("S3_PATH_STYLE") == "true" - prefix := os.Getenv("S3_PREFIX") - if prefix == "" { - prefix = "ntfy-s3-test" - } - - client := New(&Config{ - AccessKey: accessKey, - SecretKey: secretKey, - Region: region, - Endpoint: endpoint, - Bucket: bucket, - Prefix: prefix, - PathStyle: pathStyle, - }) - - ctx := context.Background() - - // Clean up any leftover objects from previous runs - existing, err := client.ListObjectsV2(ctx) - require.Nil(t, err) - if len(existing) > 0 { - keys := make([]string, len(existing)) - for i, obj := range existing { - keys[i] = obj.Key - } - // Batch delete in groups of 1000 - for i := 0; i < len(keys); i += 1000 { - end := i + 1000 - if end > len(keys) { - end = len(keys) - } - err := client.DeleteObjects(ctx, keys[i:end]) - require.Nil(t, err) - } - } - - t.Run("PutGetDelete", func(t *testing.T) { - key := "test-object" - content := "hello from ntfy s3 test" - - // Put - err := 
client.PutObject(ctx, key, strings.NewReader(content), 0) - require.Nil(t, err) - - // Get - reader, size, err := client.GetObject(ctx, key) - require.Nil(t, err) - require.Equal(t, int64(len(content)), size) - data, err := io.ReadAll(reader) - reader.Close() - require.Nil(t, err) - require.Equal(t, content, string(data)) - - // Delete - err = client.DeleteObjects(ctx, []string{key}) - require.Nil(t, err) - - // Get after delete should fail - _, _, err = client.GetObject(ctx, key) - require.Error(t, err) - var errResp *errorResponse - require.ErrorAs(t, err, &errResp) - require.Equal(t, 404, errResp.StatusCode) - }) - - t.Run("ListObjects", func(t *testing.T) { - // Use a sub-prefix client for isolation - listClient := New(&Config{ - AccessKey: accessKey, - SecretKey: secretKey, - Region: region, - Endpoint: endpoint, - Bucket: bucket, - Prefix: prefix + "/list-test", - PathStyle: pathStyle, - }) - - // Put 10 objects - for i := 0; i < 10; i++ { - err := listClient.PutObject(ctx, fmt.Sprintf("%d", i), strings.NewReader("x"), 0) - require.Nil(t, err) - } - - // List - objects, err := listClient.ListObjectsV2(ctx) - require.Nil(t, err) - require.Len(t, objects, 10) - - // Clean up - keys := make([]string, 10) - for i := range keys { - keys[i] = fmt.Sprintf("%d", i) - } - err = listClient.DeleteObjects(ctx, keys) - require.Nil(t, err) - }) - - t.Run("LargeObject", func(t *testing.T) { - key := "large-object" - data := make([]byte, 5*1024*1024) // 5 MB - for i := range data { - data[i] = byte(i % 256) - } - - err := client.PutObject(ctx, key, bytes.NewReader(data), 0) - require.Nil(t, err) - - reader, size, err := client.GetObject(ctx, key) - require.Nil(t, err) - require.Equal(t, int64(len(data)), size) - got, err := io.ReadAll(reader) - reader.Close() - require.Nil(t, err) - require.Equal(t, data, got) - - err = client.DeleteObjects(ctx, []string{key}) - require.Nil(t, err) - }) + t.Fatal("timed out waiting for bucket to be empty") +} + +func waitForCount(t 
*testing.T, client *Client, expected int) { + t.Helper() + for i := 0; i < 20; i++ { + objects, err := client.ListObjectsV2(context.Background()) + require.Nil(t, err) + if len(objects) == expected { + return + } + time.Sleep(200 * time.Millisecond) + } + objects, _ := client.ListObjectsV2(context.Background()) + t.Fatalf("timed out waiting for %d objects, got %d", expected, len(objects)) } diff --git a/s3/util.go b/s3/util.go index 1f4c2dd9..ae692735 100644 --- a/s3/util.go +++ b/s3/util.go @@ -34,6 +34,9 @@ const ( // maxPages is the max number of pages to iterate through when listing objects maxPages = 500 + + // maxDeleteBatchSize is the maximum number of keys per S3 DeleteObjects call + maxDeleteBatchSize = 1000 ) // ParseURL parses an S3 URL of the form: From f8397838e625768516ce7ac5967c581ef53deb55 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 16:05:48 -0400 Subject: [PATCH 25/32] Update docs --- cmd/serve.go | 2 +- docs/config.md | 180 +++++++++++++++++++++++++---------------------- docs/releases.md | 2 +- 3 files changed, 97 insertions(+), 87 deletions(-) diff --git a/cmd/serve.go b/cmd/serve.go index 52794a07..2af1f389 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -52,7 +52,7 @@ var flagsServe = append( altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-users", Aliases: []string{"auth_users"}, EnvVars: []string{"NTFY_AUTH_USERS"}, Usage: "pre-provisioned declarative users"}), altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}), altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}), - altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, 
Usage: "cache directory for attached files, or S3 URL (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION)"}), + altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files, or S3 URL (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT])"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}), altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-expiry-duration", Aliases: []string{"attachment_expiry_duration", "X"}, EnvVars: []string{"NTFY_ATTACHMENT_EXPIRY_DURATION"}, Value: util.FormatDuration(server.DefaultAttachmentExpiryDuration), Usage: "duration after which uploaded attachments will be deleted (e.g. 3h, 20h)"}), diff --git a/docs/config.md b/docs/config.md index ae7547b3..de534241 100644 --- a/docs/config.md +++ b/docs/config.md @@ -490,12 +490,14 @@ Subscribers can retrieve cached messaging using the [`poll=1` parameter](subscri ## Attachments If desired, you may allow users to upload and [attach files to notifications](publish.md#attachments). To enable this feature, you have to configure an attachment storage backend and a base URL (`base-url`). Attachments can be stored -either on the local filesystem or in an S3-compatible object store, both using the `attachment-cache-dir` option. 
+either on the [local filesystem](#filesystem-storage) or in an [S3-compatible object store](#s3-storage), both using the `attachment-cache-dir` option. Once configured, you can upload attachments via PUT. By default, attachments are stored **for only 3 hours**. The main reason for this is to avoid legal issues and such when hosting user controlled content. Typically, this is more than enough time for the user (or the auto download -feature) to download the file. The following config options are relevant to attachments: +feature) to download the file. You can increase this time by [purchasing ntfy Pro](https://ntfy.sh/app) via the web app. + +The following config options are relevant to attachments: * `base-url` is the root URL for the ntfy server; this is needed for the generated attachment URLs * `attachment-cache-dir` is the cache directory for attached files, or an S3 URL for object storage @@ -503,6 +505,13 @@ feature) to download the file. The following config options are relevant to atta * `attachment-file-size-limit` is the per-file attachment size limit (e.g. 300k, 2M, 100M, default: 15M) * `attachment-expiry-duration` is the duration after which uploaded attachments will be deleted (e.g. 3h, 20h, default: 3h) +!!! warning + ntfy takes full control over the attachment directory or S3 bucket. Files that match the message ID format without + entries in the message table will be deleted. **Do not use a directory or S3 bucket that is also used for something else.** + +Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit` +and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse. 
+ ### Filesystem storage Here's an example config using the local filesystem for attachment storage: @@ -538,22 +547,23 @@ When `endpoint` is specified, path-style addressing is enabled automatically (us === "/etc/ntfy/server.yml (AWS S3)" ``` yaml base-url: "https://ntfy.sh" - attachment-cache-dir: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1" + attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1" + ``` + +=== "/etc/ntfy/server.yml (DigitalOcean Spaces)" + ``` yaml + base-url: "https://ntfy.sh" + attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=nyc3&endpoint=https://nyc3.digitaloceanspaces.com" ``` === "/etc/ntfy/server.yml (custom endpoint)" ``` yaml base-url: "https://ntfy.sh" - attachment-cache-dir: "s3://AKID:SECRET@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com" + attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com" ``` -**Cleanup behavior:** A background sync runs every 15 minutes to reconcile the S3 bucket (or configured prefix) with -the server's message database. Objects whose keys match attachment file IDs that are no longer referenced in the database -(and are older than 1 hour) are automatically deleted. This also cleans up incomplete S3 multipart uploads that were -abandoned due to interrupted or failed attachment uploads. - -Please also refer to the [rate limiting](#rate-limiting) settings below, specifically `visitor-attachment-total-size-limit` -and `visitor-attachment-daily-bandwidth-limit`. Setting these conservatively is necessary to avoid abuse. +Note that the access key and secret key may have to be URL encoded. For instance, a secret key `YmxhY+mxhYmxhC` (note the `+`) should +be encoded as `YmxhY%2BmxhYmxhC` (note the `%2B`), so the URL would be `s3://ACCESS_KEY:YmxhY%2BmxhYmxhC@my-bucket/attachments...`. 
## Access control By default, the ntfy server is open for everyone, meaning **everyone can read and write to any topic** (this is how @@ -2125,80 +2135,80 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`). `cache_duration` and `cache-duration` are both supported. This is to support stricter YAML parsers that do not support dashes. -| Config option | Env variable | Format | Default | Description | -|--------------------------------------------|-------------------------------------------------|-----------------------------------------------------|-------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `base-url` | `NTFY_BASE_URL` | *URL* | - | Public facing base URL of the service (e.g. `https://ntfy.sh`) | -| `listen-http` | `NTFY_LISTEN_HTTP` | `[host]:port` | `:80` | Listen address for the HTTP web server | -| `listen-https` | `NTFY_LISTEN_HTTPS` | `[host]:port` | - | Listen address for the HTTPS web server. If set, you also need to set `key-file` and `cert-file`. | -| `listen-unix` | `NTFY_LISTEN_UNIX` | *filename* | - | Path to a Unix socket to listen on | -| `listen-unix-mode` | `NTFY_LISTEN_UNIX_MODE` | *file mode* | *system default* | File mode of the Unix socket, e.g. 0700 or 0777 | -| `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. | -| `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. | -| `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). 
| +| Config option | Env variable | Format | Default | Description | +|--------------------------------------------|-------------------------------------------------|-----------------------------------------------------|-------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `base-url` | `NTFY_BASE_URL` | *URL* | - | Public facing base URL of the service (e.g. `https://ntfy.sh`) | +| `listen-http` | `NTFY_LISTEN_HTTP` | `[host]:port` | `:80` | Listen address for the HTTP web server | +| `listen-https` | `NTFY_LISTEN_HTTPS` | `[host]:port` | - | Listen address for the HTTPS web server. If set, you also need to set `key-file` and `cert-file`. | +| `listen-unix` | `NTFY_LISTEN_UNIX` | *filename* | - | Path to a Unix socket to listen on | +| `listen-unix-mode` | `NTFY_LISTEN_UNIX_MODE` | *file mode* | *system default* | File mode of the Unix socket, e.g. 0700 or 0777 | +| `key-file` | `NTFY_KEY_FILE` | *filename* | - | HTTPS/TLS private key file, only used if `listen-https` is set. | +| `cert-file` | `NTFY_CERT_FILE` | *filename* | - | HTTPS/TLS certificate file, only used if `listen-https` is set. | +| `firebase-key-file` | `NTFY_FIREBASE_KEY_FILE` | *filename* | - | If set, also publish messages to a Firebase Cloud Messaging (FCM) topic for your app. This is optional and only required to save battery when using the Android app. See [Firebase (FCM)](#firebase-fcm). | | `database-url` | `NTFY_DATABASE_URL` | *string (connection URL)* | - | PostgreSQL connection string (e.g. `postgres://user:pass@host:5432/ntfy`). If set, uses PostgreSQL for all database-backed stores (message cache, user manager, web push) instead of SQLite. See [database options](#database-options). 
| -| `database-replica-urls` | `NTFY_DATABASE_REPLICA_URLS` | *list of strings (connection URLs)* | - | PostgreSQL read replica connection strings. Non-critical read-only queries are distributed across replicas (round-robin) with automatic fallback to primary. Requires `database-url`. See [read replicas](#read-replicas). | -| `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). | -| `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. | -| `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) | -| `cache-batch-size` | `NTFY_CACHE_BATCH_SIZE` | *int* | 0 | Max size of messages to batch together when writing to message cache (if zero, writes are synchronous) | -| `cache-batch-timeout` | `NTFY_CACHE_BATCH_TIMEOUT` | *duration* | 0s | Timeout for batched async writes to the message cache (if zero, writes are synchronous) | -| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control (SQLite). If set, enables authentication and access control. Not required if `database-url` is set. See [access control](#access-control). | -| `auth-default-access` | `NTFY_AUTH_DEFAULT_ACCESS` | `read-write`, `read-only`, `write-only`, `deny-all` | `read-write` | Default permissions if no matching entries in the auth database are found. Default is `read-write`. | -| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. 
X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) | -| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) | -| `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header | -| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory or S3 URL* | - | Cache directory for attached files, or S3 URL for object storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION`). | -| `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limits is exceeded, new attachments will be rejected. | -| `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachment will be rejected. | -| `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. | -| `smtp-sender-addr` | `NTFY_SMTP_SENDER_ADDR` | `host:port` | - | SMTP server address to allow email sending | -| `smtp-sender-user` | `NTFY_SMTP_SENDER_USER` | *string* | - | SMTP user; only used if e-mail sending is enabled | -| `smtp-sender-pass` | `NTFY_SMTP_SENDER_PASS` | *string* | - | SMTP password; only used if e-mail sending is enabled | -| `smtp-sender-from` | `NTFY_SMTP_SENDER_FROM` | *e-mail address* | - | SMTP sender e-mail address; only used if e-mail sending is enabled | -| `smtp-server-listen` | `NTFY_SMTP_SERVER_LISTEN` | `[ip]:port` | - | Defines the IP address and port the SMTP server will listen on, e.g. 
`:25` or `1.2.3.4:25` | -| `smtp-server-domain` | `NTFY_SMTP_SERVER_DOMAIN` | *domain name* | - | SMTP server e-mail domain, e.g. `ntfy.sh` | -| `smtp-server-addr-prefix` | `NTFY_SMTP_SERVER_ADDR_PREFIX` | *string* | - | Optional prefix for the e-mail addresses to prevent spam, e.g. `ntfy-` | -| `twilio-account` | `NTFY_TWILIO_ACCOUNT` | *string* | - | Twilio account SID, e.g. AC12345beefbeef67890beefbeef122586 | -| `twilio-auth-token` | `NTFY_TWILIO_AUTH_TOKEN` | *string* | - | Twilio auth token, e.g. affebeef258625862586258625862586 | -| `twilio-phone-number` | `NTFY_TWILIO_PHONE_NUMBER` | *string* | - | Twilio outgoing phone number, e.g. +18775132586 | -| `twilio-verify-service` | `NTFY_TWILIO_VERIFY_SERVICE` | *string* | - | Twilio Verify service SID, e.g. VA12345beefbeef67890beefbeef122586 | -| `keepalive-interval` | `NTFY_KEEPALIVE_INTERVAL` | *duration* | 45s | Interval in which keepalive messages are sent to the client. This is to prevent intermediaries closing the connection for inactivity. Note that the Android app has a hardcoded timeout at 77s, so it should be less than that. | -| `manager-interval` | `NTFY_MANAGER_INTERVAL` | *duration* | 1m | Interval in which the manager prunes old messages, deletes topics and prints the stats. | -| `message-size-limit` | `NTFY_MESSAGE_SIZE_LIMIT` | *size* | 4K | The size limit for the message body. Please note that this is largely untested, and that FCM/APNS have limits around 4KB. If you increase this size limit, FCM and APNS will NOT work for large messages. | -| `message-delay-limit` | `NTFY_MESSAGE_DELAY_LIMIT` | *duration* | 3d | Amount of time a message can be [scheduled](publish.md#scheduled-delivery) into the future when using the `Delay` header | -| `global-topic-limit` | `NTFY_GLOBAL_TOPIC_LIMIT` | *number* | 15,000 | Rate limiting: Total number of topics before the server rejects new topics. 
| -| `upstream-base-url` | `NTFY_UPSTREAM_BASE_URL` | *URL* | `https://ntfy.sh` | Forward poll request to an upstream server, this is needed for iOS push notifications for self-hosted servers | -| `upstream-access-token` | `NTFY_UPSTREAM_ACCESS_TOKEN` | *string* | `tk_zyYLYj...` | Access token to use for the upstream server; needed only if upstream rate limits are exceeded or upstream server requires auth | -| `visitor-attachment-total-size-limit` | `NTFY_VISITOR_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 100M | Rate limiting: Total storage limit used for attachments per visitor, for all attachments combined. Storage is freed after attachments expire. See `attachment-expiry-duration`. | -| `visitor-attachment-daily-bandwidth-limit` | `NTFY_VISITOR_ATTACHMENT_DAILY_BANDWIDTH_LIMIT` | *size* | 500M | Rate limiting: Total daily attachment download/upload traffic limit per visitor. This is to protect your bandwidth costs from exploding. | -| `visitor-email-limit-burst` | `NTFY_VISITOR_EMAIL_LIMIT_BURST` | *number* | 16 | Rate limiting:Initial limit of e-mails per visitor | -| `visitor-email-limit-replenish` | `NTFY_VISITOR_EMAIL_LIMIT_REPLENISH` | *duration* | 1h | Rate limiting: Strongly related to `visitor-email-limit-burst`: The rate at which the bucket is refilled | -| `visitor-message-daily-limit` | `NTFY_VISITOR_MESSAGE_DAILY_LIMIT` | *number* | - | Rate limiting: Allowed number of messages per day per visitor, reset every day at midnight (UTC). By default, this value is unset. | -| `visitor-request-limit-burst` | `NTFY_VISITOR_REQUEST_LIMIT_BURST` | *number* | 60 | Rate limiting: Allowed GET/PUT/POST requests per second, per visitor. 
This setting is the initial bucket of requests each visitor has | -| `visitor-request-limit-replenish` | `NTFY_VISITOR_REQUEST_LIMIT_REPLENISH` | *duration* | 5s | Rate limiting: Strongly related to `visitor-request-limit-burst`: The rate at which the bucket is refilled | -| `visitor-request-limit-exempt-hosts` | `NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS` | *comma-separated host/IP/CIDR list* | - | Rate limiting: List of hostnames and IPs to be exempt from request rate limiting | -| `visitor-subscription-limit` | `NTFY_VISITOR_SUBSCRIPTION_LIMIT` | *number* | 30 | Rate limiting: Number of subscriptions per visitor (IP address) | -| `visitor-subscriber-rate-limiting` | `NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING` | *bool* | `false` | Rate limiting: Enables subscriber-based rate limiting | -| `visitor-prefix-bits-ipv4` | `NTFY_VISITOR_PREFIX_BITS_IPV4` | *number* | 32 | Rate limiting: Number of bits to use for IPv4 visitor prefix, e.g. 24 for /24 | -| `visitor-prefix-bits-ipv6` | `NTFY_VISITOR_PREFIX_BITS_IPV6` | *number* | 64 | Rate limiting: Number of bits to use for IPv6 visitor prefix, e.g. 48 for /48 | -| `web-root` | `NTFY_WEB_ROOT` | *path*, e.g. `/` or `/app`, or `disable` | `/` | Sets root of the web app (e.g. 
/, or /app), or disables it entirely (disable) | -| `enable-signup` | `NTFY_ENABLE_SIGNUP` | *boolean* (`true` or `false`) | `false` | Allows users to sign up via the web app, or API | -| `enable-login` | `NTFY_ENABLE_LOGIN` | *boolean* (`true` or `false`) | `false` | Allows users to log in via the web app, or API | -| `enable-reservations` | `NTFY_ENABLE_RESERVATIONS` | *boolean* (`true` or `false`) | `false` | Allows users to reserve topics (if their tier allows it) | -| `require-login` | `NTFY_REQUIRE_LOGIN` | *boolean* (`true` or `false`) | `false` | All actions via the web app require a login | -| `stripe-secret-key` | `NTFY_STRIPE_SECRET_KEY` | *string* | - | Payments: Key used for the Stripe API communication, this enables payments | -| `stripe-webhook-key` | `NTFY_STRIPE_WEBHOOK_KEY` | *string* | - | Payments: Key required to validate the authenticity of incoming webhooks from Stripe | -| `billing-contact` | `NTFY_BILLING_CONTACT` | *email address* or *website* | - | Payments: Email or website displayed in Upgrade dialog as a billing contact | -| `web-push-public-key` | `NTFY_WEB_PUSH_PUBLIC_KEY` | *string* | - | Web Push: Public Key. Run `ntfy webpush keys` to generate | -| `web-push-private-key` | `NTFY_WEB_PUSH_PRIVATE_KEY` | *string* | - | Web Push: Private Key. Run `ntfy webpush keys` to generate | -| `web-push-file` | `NTFY_WEB_PUSH_FILE` | *string* | - | Web Push: Database file that stores subscriptions | -| `web-push-email-address` | `NTFY_WEB_PUSH_EMAIL_ADDRESS` | *string* | - | Web Push: Sender email address | -| `web-push-startup-queries` | `NTFY_WEB_PUSH_STARTUP_QUERIES` | *string* | - | Web Push: SQL queries to run against subscription database at startup | -| `web-push-expiry-duration` | `NTFY_WEB_PUSH_EXPIRY_DURATION` | *duration* | 60d | Web Push: Duration after which a subscription is considered stale and will be deleted. This is to prevent stale subscriptions. 
| -| `web-push-expiry-warning-duration` | `NTFY_WEB_PUSH_EXPIRY_WARNING_DURATION` | *duration* | 55d | Web Push: Duration after which a warning is sent to subscribers that their subscription will expire soon. This is to prevent stale subscriptions. | -| `log-format` | `NTFY_LOG_FORMAT` | *string* | `text` | Defines the output format, can be text or json | -| `log-file` | `NTFY_LOG_FILE` | *string* | - | Defines the filename to write logs to. If this is not set, ntfy logs to stderr | -| `log-level` | `NTFY_LOG_LEVEL` | *string* | `info` | Defines the default log level, can be one of trace, debug, info, warn or error | +| `database-replica-urls` | `NTFY_DATABASE_REPLICA_URLS` | *list of strings (connection URLs)* | - | PostgreSQL read replica connection strings. Non-critical read-only queries are distributed across replicas (round-robin) with automatic fallback to primary. Requires `database-url`. | +| `cache-file` | `NTFY_CACHE_FILE` | *filename* | - | If set, messages are cached in a local SQLite database instead of only in-memory. This allows for service restarts without losing messages in support of the since= parameter. See [message cache](#message-cache). | +| `cache-duration` | `NTFY_CACHE_DURATION` | *duration* | 12h | Duration for which messages will be buffered before they are deleted. This is required to support the `since=...` and `poll=1` parameter. Set this to `0` to disable the cache entirely. 
| +| `cache-startup-queries` | `NTFY_CACHE_STARTUP_QUERIES` | *string (SQL queries)* | - | SQL queries to run during database startup; this is useful for tuning and [enabling WAL mode](#message-cache) | +| `cache-batch-size` | `NTFY_CACHE_BATCH_SIZE` | *int* | 0 | Max size of messages to batch together when writing to message cache (if zero, writes are synchronous) | +| `cache-batch-timeout` | `NTFY_CACHE_BATCH_TIMEOUT` | *duration* | 0s | Timeout for batched async writes to the message cache (if zero, writes are synchronous) | +| `auth-file` | `NTFY_AUTH_FILE` | *filename* | - | Auth database file used for access control (SQLite). If set, enables authentication and access control. Not required if `database-url` is set. See [access control](#access-control). | +| `auth-default-access` | `NTFY_AUTH_DEFAULT_ACCESS` | `read-write`, `read-only`, `write-only`, `deny-all` | `read-write` | Default permissions if no matching entries in the auth database are found. Default is `read-write`. | +| `behind-proxy` | `NTFY_BEHIND_PROXY` | *bool* | false | If set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting) | +| `proxy-forwarded-header` | `NTFY_PROXY_FORWARDED_HEADER` | *string* | `X-Forwarded-For` | Use specified header to determine visitor IP address (for rate limiting) | +| `proxy-trusted-hosts` | `NTFY_PROXY_TRUSTED_HOSTS` | *comma-separated host/IP/CIDR list* | - | Comma-separated list of trusted IP addresses, hosts, or CIDRs to remove from forwarded header | +| `attachment-cache-dir` | `NTFY_ATTACHMENT_CACHE_DIR` | *directory or S3 URL* | - | Cache directory for attached files, or S3 URL for object storage (format: `s3://KEY:SECRET@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]`). | +| `attachment-total-size-limit` | `NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 5G | Limit of the on-disk attachment cache directory. If the limit is exceeded, new attachments will be rejected. 
| +| `attachment-file-size-limit` | `NTFY_ATTACHMENT_FILE_SIZE_LIMIT` | *size* | 15M | Per-file attachment size limit (e.g. 300k, 2M, 100M). Larger attachments will be rejected. | +| `attachment-expiry-duration` | `NTFY_ATTACHMENT_EXPIRY_DURATION` | *duration* | 3h | Duration after which uploaded attachments will be deleted (e.g. 3h, 20h). Strongly affects `visitor-attachment-total-size-limit`. | +| `smtp-sender-addr` | `NTFY_SMTP_SENDER_ADDR` | `host:port` | - | SMTP server address to allow email sending | +| `smtp-sender-user` | `NTFY_SMTP_SENDER_USER` | *string* | - | SMTP user; only used if e-mail sending is enabled | +| `smtp-sender-pass` | `NTFY_SMTP_SENDER_PASS` | *string* | - | SMTP password; only used if e-mail sending is enabled | +| `smtp-sender-from` | `NTFY_SMTP_SENDER_FROM` | *e-mail address* | - | SMTP sender e-mail address; only used if e-mail sending is enabled | +| `smtp-server-listen` | `NTFY_SMTP_SERVER_LISTEN` | `[ip]:port` | - | Defines the IP address and port the SMTP server will listen on, e.g. `:25` or `1.2.3.4:25` | +| `smtp-server-domain` | `NTFY_SMTP_SERVER_DOMAIN` | *domain name* | - | SMTP server e-mail domain, e.g. `ntfy.sh` | +| `smtp-server-addr-prefix` | `NTFY_SMTP_SERVER_ADDR_PREFIX` | *string* | - | Optional prefix for the e-mail addresses to prevent spam, e.g. `ntfy-` | +| `twilio-account` | `NTFY_TWILIO_ACCOUNT` | *string* | - | Twilio account SID, e.g. AC12345beefbeef67890beefbeef122586 | +| `twilio-auth-token` | `NTFY_TWILIO_AUTH_TOKEN` | *string* | - | Twilio auth token, e.g. affebeef258625862586258625862586 | +| `twilio-phone-number` | `NTFY_TWILIO_PHONE_NUMBER` | *string* | - | Twilio outgoing phone number, e.g. +18775132586 | +| `twilio-verify-service` | `NTFY_TWILIO_VERIFY_SERVICE` | *string* | - | Twilio Verify service SID, e.g. VA12345beefbeef67890beefbeef122586 | +| `keepalive-interval` | `NTFY_KEEPALIVE_INTERVAL` | *duration* | 45s | Interval in which keepalive messages are sent to the client. 
This is to prevent intermediaries closing the connection for inactivity. Note that the Android app has a hardcoded timeout at 77s, so it should be less than that. | +| `manager-interval` | `NTFY_MANAGER_INTERVAL` | *duration* | 1m | Interval in which the manager prunes old messages, deletes topics and prints the stats. | +| `message-size-limit` | `NTFY_MESSAGE_SIZE_LIMIT` | *size* | 4K | The size limit for the message body. Please note that this is largely untested, and that FCM/APNS have limits around 4KB. If you increase this size limit, FCM and APNS will NOT work for large messages. | +| `message-delay-limit` | `NTFY_MESSAGE_DELAY_LIMIT` | *duration* | 3d | Amount of time a message can be [scheduled](publish.md#scheduled-delivery) into the future when using the `Delay` header | +| `global-topic-limit` | `NTFY_GLOBAL_TOPIC_LIMIT` | *number* | 15,000 | Rate limiting: Total number of topics before the server rejects new topics. | +| `upstream-base-url` | `NTFY_UPSTREAM_BASE_URL` | *URL* | `https://ntfy.sh` | Forward poll request to an upstream server, this is needed for iOS push notifications for self-hosted servers | +| `upstream-access-token` | `NTFY_UPSTREAM_ACCESS_TOKEN` | *string* | `tk_zyYLYj...` | Access token to use for the upstream server; needed only if upstream rate limits are exceeded or upstream server requires auth | +| `visitor-attachment-total-size-limit` | `NTFY_VISITOR_ATTACHMENT_TOTAL_SIZE_LIMIT` | *size* | 100M | Rate limiting: Total storage limit used for attachments per visitor, for all attachments combined. Storage is freed after attachments expire. See `attachment-expiry-duration`. | +| `visitor-attachment-daily-bandwidth-limit` | `NTFY_VISITOR_ATTACHMENT_DAILY_BANDWIDTH_LIMIT` | *size* | 500M | Rate limiting: Total daily attachment download/upload traffic limit per visitor. This is to protect your bandwidth costs from exploding. 
| +| `visitor-email-limit-burst` | `NTFY_VISITOR_EMAIL_LIMIT_BURST` | *number* | 16 | Rate limiting: Initial limit of e-mails per visitor | +| `visitor-email-limit-replenish` | `NTFY_VISITOR_EMAIL_LIMIT_REPLENISH` | *duration* | 1h | Rate limiting: Strongly related to `visitor-email-limit-burst`: The rate at which the bucket is refilled | +| `visitor-message-daily-limit` | `NTFY_VISITOR_MESSAGE_DAILY_LIMIT` | *number* | - | Rate limiting: Allowed number of messages per day per visitor, reset every day at midnight (UTC). By default, this value is unset. | +| `visitor-request-limit-burst` | `NTFY_VISITOR_REQUEST_LIMIT_BURST` | *number* | 60 | Rate limiting: Allowed GET/PUT/POST requests per second, per visitor. This setting is the initial bucket of requests each visitor has | +| `visitor-request-limit-replenish` | `NTFY_VISITOR_REQUEST_LIMIT_REPLENISH` | *duration* | 5s | Rate limiting: Strongly related to `visitor-request-limit-burst`: The rate at which the bucket is refilled | +| `visitor-request-limit-exempt-hosts` | `NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS` | *comma-separated host/IP/CIDR list* | - | Rate limiting: List of hostnames and IPs to be exempt from request rate limiting | +| `visitor-subscription-limit` | `NTFY_VISITOR_SUBSCRIPTION_LIMIT` | *number* | 30 | Rate limiting: Number of subscriptions per visitor (IP address) | +| `visitor-subscriber-rate-limiting` | `NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING` | *bool* | `false` | Rate limiting: Enables subscriber-based rate limiting | +| `visitor-prefix-bits-ipv4` | `NTFY_VISITOR_PREFIX_BITS_IPV4` | *number* | 32 | Rate limiting: Number of bits to use for IPv4 visitor prefix, e.g. 24 for /24 | +| `visitor-prefix-bits-ipv6` | `NTFY_VISITOR_PREFIX_BITS_IPV6` | *number* | 64 | Rate limiting: Number of bits to use for IPv6 visitor prefix, e.g. 48 for /48 | +| `web-root` | `NTFY_WEB_ROOT` | *path*, e.g. `/` or `/app`, or `disable` | `/` | Sets root of the web app (e.g. 
/, or /app), or disables it entirely (disable) | +| `enable-signup` | `NTFY_ENABLE_SIGNUP` | *boolean* (`true` or `false`) | `false` | Allows users to sign up via the web app, or API | +| `enable-login` | `NTFY_ENABLE_LOGIN` | *boolean* (`true` or `false`) | `false` | Allows users to log in via the web app, or API | +| `enable-reservations` | `NTFY_ENABLE_RESERVATIONS` | *boolean* (`true` or `false`) | `false` | Allows users to reserve topics (if their tier allows it) | +| `require-login` | `NTFY_REQUIRE_LOGIN` | *boolean* (`true` or `false`) | `false` | All actions via the web app require a login | +| `stripe-secret-key` | `NTFY_STRIPE_SECRET_KEY` | *string* | - | Payments: Key used for the Stripe API communication, this enables payments | +| `stripe-webhook-key` | `NTFY_STRIPE_WEBHOOK_KEY` | *string* | - | Payments: Key required to validate the authenticity of incoming webhooks from Stripe | +| `billing-contact` | `NTFY_BILLING_CONTACT` | *email address* or *website* | - | Payments: Email or website displayed in Upgrade dialog as a billing contact | +| `web-push-public-key` | `NTFY_WEB_PUSH_PUBLIC_KEY` | *string* | - | Web Push: Public Key. Run `ntfy webpush keys` to generate | +| `web-push-private-key` | `NTFY_WEB_PUSH_PRIVATE_KEY` | *string* | - | Web Push: Private Key. Run `ntfy webpush keys` to generate | +| `web-push-file` | `NTFY_WEB_PUSH_FILE` | *string* | - | Web Push: Database file that stores subscriptions | +| `web-push-email-address` | `NTFY_WEB_PUSH_EMAIL_ADDRESS` | *string* | - | Web Push: Sender email address | +| `web-push-startup-queries` | `NTFY_WEB_PUSH_STARTUP_QUERIES` | *string* | - | Web Push: SQL queries to run against subscription database at startup | +| `web-push-expiry-duration` | `NTFY_WEB_PUSH_EXPIRY_DURATION` | *duration* | 60d | Web Push: Duration after which a subscription is considered stale and will be deleted. This is to prevent stale subscriptions. 
| +| `web-push-expiry-warning-duration` | `NTFY_WEB_PUSH_EXPIRY_WARNING_DURATION` | *duration* | 55d | Web Push: Duration after which a warning is sent to subscribers that their subscription will expire soon. This is to prevent stale subscriptions. | +| `log-format` | `NTFY_LOG_FORMAT` | *string* | `text` | Defines the output format, can be text or json | +| `log-file` | `NTFY_LOG_FILE` | *string* | - | Defines the filename to write logs to. If this is not set, ntfy logs to stderr | +| `log-level` | `NTFY_LOG_LEVEL` | *string* | `info` | Defines the default log level, can be one of trace, debug, info, warn or error | The format for a *duration* is: `(smhd)`, e.g. 30s, 20m, 1h or 3d. The format for a *size* is: `(GMK)`, e.g. 1G, 200M or 4000k. @@ -2249,7 +2259,7 @@ OPTIONS: --auth-file value, --auth_file value, -H value auth database file used for access control [$NTFY_AUTH_FILE] --auth-startup-queries value, --auth_startup_queries value queries run when the auth database is initialized [$NTFY_AUTH_STARTUP_QUERIES] --auth-default-access value, --auth_default_access value, -p value default permissions if no matching entries in the auth database are found (default: "read-write") [$NTFY_AUTH_DEFAULT_ACCESS] - --attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files, or S3 URL (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION) [$NTFY_ATTACHMENT_CACHE_DIR] + --attachment-cache-dir value, --attachment_cache_dir value cache directory for attached files, or S3 URL (s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]) [$NTFY_ATTACHMENT_CACHE_DIR] --attachment-total-size-limit value, --attachment_total_size_limit value, -A value limit of the on-disk attachment cache (default: "5G") [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT] --attachment-file-size-limit value, --attachment_file_size_limit value, -Y value per-file attachment size limit (e.g. 
300k, 2M, 100M) (default: "15M") [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT] --attachment-expiry-duration value, --attachment_expiry_duration value, -X value duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: "3h") [$NTFY_ATTACHMENT_EXPIRY_DURATION] diff --git a/docs/releases.md b/docs/releases.md index b16608c6..afdb8065 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1802,7 +1802,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release **Features:** -* Add S3-compatible object storage as an alternative attachment backend via `attachment-cache-dir` config option +* Add S3-compatible object storage as an alternative [attachment](config.md#attachments) backend via `attachment-cache-dir` config option **Bug fixes + maintenance:** From ef051afc09a3dc0ee965bd22c1a1d4ad0105881f Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 16:06:57 -0400 Subject: [PATCH 26/32] Update base-url in examples --- docs/config.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/config.md b/docs/config.md index de534241..b6ae3009 100644 --- a/docs/config.md +++ b/docs/config.md @@ -517,13 +517,13 @@ Here's an example config using the local filesystem for attachment storage: === "/etc/ntfy/server.yml (minimal)" ``` yaml - base-url: "https://ntfy.sh" + base-url: "https://ntfy.example.com" attachment-cache-dir: "/var/cache/ntfy/attachments" ``` === "/etc/ntfy/server.yml (all options)" ``` yaml - base-url: "https://ntfy.sh" + base-url: "https://ntfy.example.com" attachment-cache-dir: "/var/cache/ntfy/attachments" attachment-total-size-limit: "5G" attachment-file-size-limit: "15M" @@ -546,19 +546,19 @@ When `endpoint` is specified, path-style addressing is enabled automatically (us === "/etc/ntfy/server.yml (AWS S3)" ``` yaml - base-url: "https://ntfy.sh" + base-url: "https://ntfy.example.com" attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1" ``` 
=== "/etc/ntfy/server.yml (DigitalOcean Spaces)" ``` yaml - base-url: "https://ntfy.sh" + base-url: "https://ntfy.example.com" attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=nyc3&endpoint=https://nyc3.digitaloceanspaces.com" ``` === "/etc/ntfy/server.yml (custom endpoint)" ``` yaml - base-url: "https://ntfy.sh" + base-url: "https://ntfy.example.com" attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1&endpoint=https://s3.example.com" ``` From a04128520d1e9fbd9d4cd35afde329d7eab3b1ed Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 16:17:17 -0400 Subject: [PATCH 27/32] Run S3 tests in CI --- .github/workflows/release.yaml | 1 + .github/workflows/test.yaml | 1 + server/server_test.go | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 4ebb9d56..3c959bb6 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -22,6 +22,7 @@ jobs: --health-retries 5 env: NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable" + NTFY_TEST_S3_URL: ${{ secrets.NTFY_TEST_S3_URL }} steps: - name: Checkout code uses: actions/checkout@v3 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 94f08fd9..803ca01f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,6 +19,7 @@ jobs: --health-retries 5 env: NTFY_TEST_DATABASE_URL: "postgres://ntfy:ntfy@localhost:5432/ntfy_test?sslmode=disable" + NTFY_TEST_S3_URL: ${{ secrets.NTFY_TEST_S3_URL }} steps: - name: Checkout code uses: actions/checkout@v3 diff --git a/server/server_test.go b/server/server_test.go index 449b6006..44b9ac94 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2145,7 +2145,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) { require.Equal(t, "myfile.txt", msg.Attachment.Name) require.Equal(t, "text/plain; 
charset=utf-8", msg.Attachment.Type) require.Equal(t, int64(21), msg.Attachment.Size) - require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix()) + require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix()-1) require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/") require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) From 4d07897d2dbaa159754398a040e0ee3d6d571610 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 16:20:45 -0400 Subject: [PATCH 28/32] RWMutex --- attachment/store.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/attachment/store.go b/attachment/store.go index 3666bdd7..f250d106 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -34,7 +34,7 @@ type Store struct { sizes map[string]int64 // File ID -> size, for subtracting on Remove localIDs func() ([]string, error) // Returns file IDs that should exist locally, used for sync() closeChan chan struct{} - mu sync.Mutex // Protects size and sizes + mu sync.RWMutex // Protects size and sizes } // NewFileStore creates a new file-system backed attachment cache @@ -191,15 +191,15 @@ func (c *Store) sync() error { // Size returns the current total size of all attachments func (c *Store) Size() int64 { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.size } // Remaining returns the remaining capacity for attachments func (c *Store) Remaining() int64 { - c.mu.Lock() - defer c.mu.Unlock() + c.mu.RLock() + defer c.mu.RUnlock() remaining := c.limit - c.size if remaining < 0 { return 0 From 69cc80ec1e86917095f19cca811470782739c6ea Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 20:52:25 -0400 Subject: [PATCH 29/32] Add comments about AWS S3 --- attachment/store.go | 5 +++-- docs/config.md | 28 ++++++++++++++++++++++++++++ docs/releases.md | 2 +- 
s3/client.go | 32 +++++++++++++++++++++++++++++++- 4 files changed, 63 insertions(+), 4 deletions(-) diff --git a/attachment/store.go b/attachment/store.go index f250d106..d70ea2ab 100644 --- a/attachment/store.go +++ b/attachment/store.go @@ -157,7 +157,7 @@ func (c *Store) sync() error { // than the grace period to account for races, and skipping objects with invalid IDs. cutoff := time.Now().Add(-orphanGracePeriod) var orphanIDs []string - var size int64 + var count, size int64 sizes := make(map[string]int64, len(remoteObjects)) for _, obj := range remoteObjects { if !fileIDRegex.MatchString(obj.ID) { @@ -166,11 +166,12 @@ func (c *Store) sync() error { if _, ok := localIDMap[obj.ID]; !ok && obj.LastModified.Before(cutoff) { orphanIDs = append(orphanIDs, obj.ID) } else { + count++ size += obj.Size sizes[obj.ID] = obj.Size } } - log.Tag(tagStore).Debug("Attachment store updated: %d attachment(s), %s", len(localIDs), util.FormatSizeHuman(size)) + log.Tag(tagStore).Debug("Attachment store updated: %d attachment(s), %s", count, util.FormatSizeHuman(size)) c.mu.Lock() c.size = size c.sizes = sizes diff --git a/docs/config.md b/docs/config.md index b6ae3009..3456b661 100644 --- a/docs/config.md +++ b/docs/config.md @@ -565,6 +565,34 @@ When `endpoint` is specified, path-style addressing is enabled automatically (us Note that the access key and secret key may have to be URL encoded. For instance, a secret key `YmxhY+mxhYmxhC` (note the `+`) should be encoded as `YmxhY%2BmxhYmxhC` (note the `%2B`), so the URL would be `s3://ACCESS_KEY:YmxhY%2BmxhYmxhC@my-bucket/attachments...`. 
+For **AWS S3**, the IAM user needs the following permissions on the bucket: + +``` json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:ListBucket", + "s3:ListBucketMultipartUploads" + ], + "Resource": "arn:aws:s3:::BUCKET_NAME" + }, + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:AbortMultipartUpload" + ], + "Resource": "arn:aws:s3:::BUCKET_NAME/*" + } + ] +} +``` + ## Access control By default, the ntfy server is open for everyone, meaning **everyone can read and write to any topic** (this is how ntfy.sh is configured). To restrict access to your own server, you can optionally configure authentication and authorization. diff --git a/docs/releases.md b/docs/releases.md index afdb8065..975c7536 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1802,7 +1802,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release **Features:** -* Add S3-compatible object storage as an alternative [attachment](config.md#attachments) backend via `attachment-cache-dir` config option +* Add S3-compatible object storage as an alternative [attachment store](config.md#attachments) via `attachment-cache-dir` config option **Bug fixes + maintenance:** diff --git a/s3/client.go b/s3/client.go index e06ff5c9..8e84bbc5 100644 --- a/s3/client.go +++ b/s3/client.go @@ -25,6 +25,32 @@ const ( // and ListObjectsV2 operations using AWS Signature V4 signing. The bucket and optional key prefix // are fixed at construction time. All operations target the same bucket and prefix. 
// +// The following IAM policy is required for AWS S3: +// +// { +// "Version": "2012-10-17", +// "Statement": [ +// { +// "Effect": "Allow", +// "Action": [ +// "s3:ListBucket", +// "s3:ListBucketMultipartUploads" +// ], +// "Resource": "arn:aws:s3:::BUCKET_NAME" +// }, +// { +// "Effect": "Allow", +// "Action": [ +// "s3:GetObject", +// "s3:PutObject", +// "s3:DeleteObject", +// "s3:AbortMultipartUpload" +// ], +// "Resource": "arn:aws:s3:::BUCKET_NAME/*" +// } +// ] +// } +// // Fields must not be modified after the Client is passed to any method or goroutine. type Client struct { config *Config @@ -149,7 +175,11 @@ func (c *Client) ListObjectsV2(ctx context.Context) ([]*Object, error) { // listObjectsV2 performs a single ListObjectsV2 request using the client's configured prefix. func (c *Client) listObjectsV2(ctx context.Context, continuationToken string) (*listObjectsV2Result, error) { - log.Tag(tagS3Client).Debug("Listing remote objects with continuation token '%s'", continuationToken) + if continuationToken == "" { + log.Tag(tagS3Client).Debug("Listing remote objects") + } else { + log.Tag(tagS3Client).Debug("Listing remote objects, continuing with token '%s'", continuationToken) + } query := url.Values{"list-type": {"2"}} if prefix := c.config.ListPrefix(); prefix != "" { query.Set("prefix", prefix) From 233ec0973d2569ab7680b2eca69c8c43e5097e4e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 21:01:33 -0400 Subject: [PATCH 30/32] bump wait time --- s3/client_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/s3/client_test.go b/s3/client_test.go index f4b85089..23cde72c 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -383,7 +383,7 @@ func newTestClient(t *testing.T) *Client { func deleteAllObjects(t *testing.T, client *Client) { t.Helper() - for i := 0; i < 20; i++ { + for i := 0; i < 60; i++ { objects, err := client.ListObjectsV2(context.Background()) require.Nil(t, err) if len(objects) == 0 { 
@@ -394,20 +394,20 @@ func deleteAllObjects(t *testing.T, client *Client) { keys[j] = obj.Key } require.Nil(t, client.DeleteObjects(context.Background(), keys)) - time.Sleep(200 * time.Millisecond) + time.Sleep(500 * time.Millisecond) } t.Fatal("timed out waiting for bucket to be empty") } func waitForCount(t *testing.T, client *Client, expected int) { t.Helper() - for i := 0; i < 20; i++ { + for i := 0; i < 60; i++ { objects, err := client.ListObjectsV2(context.Background()) require.Nil(t, err) if len(objects) == expected { return } - time.Sleep(200 * time.Millisecond) + time.Sleep(500 * time.Millisecond) } objects, _ := client.ListObjectsV2(context.Background()) t.Fatalf("timed out waiting for %d objects, got %d", expected, len(objects)) From e87a3e62feab6ae4fbe1399128406d5873bc51fc Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 21:11:21 -0400 Subject: [PATCH 31/32] Fix workflows to not double run --- .github/workflows/build.yaml | 5 ++++- .github/workflows/test.yaml | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 1baa82c2..ca44e4b6 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -1,5 +1,8 @@ name: build -on: [ push, pull_request ] +on: + push: + branches: [ main ] + pull_request: jobs: build: runs-on: ubuntu-latest diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 803ca01f..4d6bbbdb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,5 +1,8 @@ name: test -on: [ push, pull_request ] +on: + push: + branches: [ main ] + pull_request: jobs: test: runs-on: ubuntu-latest From e6192c94bd58924e1d071fd00b3cf28ea14ce43f Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 22 Mar 2026 21:24:43 -0400 Subject: [PATCH 32/32] Docs updates --- docs/config.md | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/docs/config.md 
b/docs/config.md
index 3456b661..c9e6687d 100644
--- a/docs/config.md
+++ b/docs/config.md
@@ -533,22 +533,15 @@ Here's an example config using the local filesystem for attachment storage:
 ```
 
 ### S3 storage
-As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. AWS S3,
-DigitalOcean Spaces). This is useful for HA/cloud deployments where you don't want to rely on local disk storage.
-
-To use S3, set `attachment-cache-dir` to an S3 URL with the following format:
+As an alternative to the local filesystem, you can store attachments in an S3-compatible object store (e.g. [AWS S3](https://aws.amazon.com/s3/),
+[DigitalOcean Spaces](https://www.digitalocean.com/products/spaces)). This is useful for HA/cloud deployments where you don't want to rely on local disk storage.
+To use S3-compatible storage for attachments, set `attachment-cache-dir` to an S3 URL with the following format:
 ```
 s3://ACCESS_KEY:SECRET_KEY@BUCKET[/PREFIX]?region=REGION[&endpoint=ENDPOINT]
 ```
-When `endpoint` is specified, path-style addressing is enabled automatically (useful for S3-compatible stores like DigitalOcean Spaces).
-
-=== "/etc/ntfy/server.yml (AWS S3)"
-    ``` yaml
-    base-url: "https://ntfy.example.com"
-    attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1"
-    ```
+Here are a few examples:

=== "/etc/ntfy/server.yml (DigitalOcean Spaces)"
    ``` yaml
@@ -556,6 +549,12 @@ When `endpoint` is specified, path-style addressing is enabled automatically (us
    attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=nyc3&endpoint=https://nyc3.digitaloceanspaces.com"
    ```
 
+=== "/etc/ntfy/server.yml (AWS S3)"
+    ``` yaml
+    base-url: "https://ntfy.example.com"
+    attachment-cache-dir: "s3://ACCESS_KEY:SECRET_KEY@my-bucket/attachments?region=us-east-1"
+    ```
+
=== "/etc/ntfy/server.yml (custom endpoint)"
    ``` yaml
    base-url: "https://ntfy.example.com"
@@ -565,7 +564,12 @@ When `endpoint` is specified, path-style addressing is enabled automatically (us
 Note that the access key and secret key may have to be URL encoded. For instance, a secret key `YmxhY+mxhYmxhC` (note the `+`)
 should be encoded as `YmxhY%2BmxhYmxhC` (note the `%2B`), so the URL would be `s3://ACCESS_KEY:YmxhY%2BmxhYmxhC@my-bucket/attachments...`.
 
-For **AWS S3**, the IAM user needs the following permissions on the bucket:
+!!! info
+    ntfy.sh is hosted and sponsored by DigitalOcean. I can highly recommend their public cloud offering. It's been rock solid
+    for 4 years. They offer S3-compatible storage for $5/month, with 250 GB of storage and 1 TiB of bandwidth.
+    Also, if you **use [this referral link](https://m.do.co/c/442b929528db), you can get $200 credit**.
+
+For AWS S3, the IAM user needs the following permissions on the bucket:

``` json
{