Move stuff around

This commit is contained in:
binwiederhier 2026-03-19 21:46:52 -04:00
parent a1b403d23c
commit d86e20173c
8 changed files with 396 additions and 396 deletions

View file

@ -3,7 +3,6 @@ package attachment
import (
"context"
"io"
"strings"
"time"
"heckel.io/ntfy/v2/log"
@ -38,15 +37,10 @@ func (b *s3Backend) List() ([]object, error) {
if err != nil {
return nil, err
}
prefix := b.client.Prefix
result := make([]object, 0, len(objects))
for _, obj := range objects {
id := obj.Key
if prefix != "" {
id = strings.TrimPrefix(id, prefix+"/")
}
result = append(result, object{
ID: id,
ID: obj.Key,
Size: obj.Size,
LastModified: obj.LastModified,
})

View file

@ -197,7 +197,7 @@ func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) {
func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store {
t.Helper()
host := strings.TrimPrefix(server.URL, "https://")
backend := newS3Backend(&s3.Client{
backend := newS3Backend(s3.New(&s3.Config{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
@ -206,7 +206,7 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
})
}))
cache, err := newStore(backend, totalSizeLimit, nil)
require.Nil(t, err)
t.Cleanup(func() { cache.Close() })

View file

@ -8,14 +8,12 @@ import (
"context"
"crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers
"encoding/base64"
"encoding/hex"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
@ -33,26 +31,19 @@ const (
//
// Fields must not be modified after the Client is passed to any method or goroutine.
type Client struct {
AccessKey string // AWS access key ID
SecretKey string // AWS secret access key
Region string // e.g. "us-east-1"
Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com"
Bucket string // S3 bucket name
Prefix string // optional key prefix (e.g. "attachments"); prepended to all keys automatically
PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style
HTTPClient *http.Client // if nil, http.DefaultClient is used
config *Config
http *http.Client
}
// New creates a new S3 client from the given Config.
func New(config *Config) *Client {
httpClient := config.HTTPClient
if httpClient == nil {
httpClient = http.DefaultClient
}
return &Client{
AccessKey: config.AccessKey,
SecretKey: config.SecretKey,
Region: config.Region,
Endpoint: config.Endpoint,
Bucket: config.Bucket,
Prefix: config.Prefix,
PathStyle: config.PathStyle,
config: config,
http: httpClient,
}
}
@ -87,7 +78,7 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int6
return nil, 0, fmt.Errorf("s3: GetObject request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
resp, err := c.http.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("s3: GetObject: %w", err)
}
@ -116,35 +107,18 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
}
body.WriteString("</Delete>")
bodyBytes := body.Bytes()
payloadHash := sha256Hex(bodyBytes)
// Content-MD5 is required by the S3 protocol for DeleteObjects requests.
md5Sum := md5.Sum(bodyBytes) //nolint:gosec
contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:])
reqURL := c.bucketURL() + "?delete="
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
respBody, err := c.doWithBodyAndHeaders(ctx, http.MethodPost, c.config.BucketURL()+"?delete=", bodyBytes,
map[string]string{"Content-MD5": contentMD5}, "DeleteObjects")
if err != nil {
return fmt.Errorf("s3: DeleteObjects request: %w", err)
}
req.ContentLength = int64(len(bodyBytes))
req.Header.Set("Content-Type", "application/xml")
req.Header.Set("Content-MD5", contentMD5)
c.signV4(req, payloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return fmt.Errorf("s3: DeleteObjects: %w", err)
}
defer resp.Body.Close()
if !isHTTPSuccess(resp) {
return parseError(resp)
return err
}
// S3 may return HTTP 200 with per-key errors in the response body
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: DeleteObjects read response: %w", err)
}
var result deleteResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil // If we can't parse, assume success (Quiet mode returns empty body on success)
@ -159,9 +133,9 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
return nil
}
// ListObjects performs a single ListObjectsV2 request using the client's configured prefix.
// listObjects performs a single ListObjectsV2 request using the client's configured prefix.
// Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000).
func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) {
func (c *Client) listObjects(ctx context.Context, continuationToken string, maxKeys int) (*listResult, error) {
log.Tag(tagS3Client).Debug("ListObjects continuation=%s maxKeys=%d", continuationToken, maxKeys)
query := url.Values{"list-type": {"2"}}
if prefix := c.prefixForList(); prefix != "" {
@ -173,22 +147,9 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK
if maxKeys > 0 {
query.Set("max-keys", strconv.Itoa(maxKeys))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListObjects")
if err != nil {
return nil, fmt.Errorf("s3: ListObjects request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("s3: ListObjects: %w", err)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("s3: ListObjects read: %w", err)
}
if !isHTTPSuccess(resp) {
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
return nil, err
}
var result listObjectsV2Response
if err := xml.Unmarshal(respBody, &result); err != nil {
@ -206,7 +167,7 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK
LastModified: lastModified,
}
}
return &ListResult{
return &listResult{
Objects: objects,
IsTruncated: result.IsTruncated,
NextContinuationToken: result.NextContinuationToken,
@ -214,16 +175,21 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK
}
// ListAllObjects returns all objects under the client's configured prefix by paginating through
// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve.
// ListObjectsV2 results automatically. Keys in the returned objects have the prefix stripped,
// so they match the keys used with PutObject/GetObject/DeleteObjects. It stops after 10,000
// pages as a safety valve.
func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
var all []Object
var token string
for page := 0; page < maxPages; page++ {
result, err := c.ListObjects(ctx, token, 0)
result, err := c.listObjects(ctx, token, 0)
if err != nil {
return nil, err
}
all = append(all, result.Objects...)
for _, obj := range result.Objects {
obj.Key = c.stripPrefix(obj.Key)
all = append(all, obj)
}
if !result.IsTruncated {
return all, nil
}
@ -232,77 +198,6 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
}
// ListMultipartUploads returns in-progress multipart uploads for the client's prefix.
// It paginates automatically, stopping after 10,000 pages as a safety valve.
func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) {
var all []MultipartUpload
var keyMarker, uploadIDMarker string
for page := 0; page < maxPages; page++ {
query := url.Values{"uploads": {""}}
if prefix := c.prefixForList(); prefix != "" {
query.Set("prefix", prefix)
}
if keyMarker != "" {
query.Set("key-marker", keyMarker)
query.Set("upload-id-marker", uploadIDMarker)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
if err != nil {
return nil, fmt.Errorf("s3: ListMultipartUploads request: %w", err)
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return nil, fmt.Errorf("s3: ListMultipartUploads: %w", err)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
resp.Body.Close()
if err != nil {
return nil, fmt.Errorf("s3: ListMultipartUploads read: %w", err)
}
if !isHTTPSuccess(resp) {
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
}
var result listMultipartUploadsResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err)
}
for _, u := range result.Uploads {
var initiated time.Time
if u.Initiated != "" {
initiated, _ = time.Parse(time.RFC3339, u.Initiated)
}
all = append(all, MultipartUpload{
Key: u.Key,
UploadID: u.UploadID,
Initiated: initiated,
})
}
if !result.IsTruncated {
return all, nil
}
keyMarker = result.NextKeyMarker
uploadIDMarker = result.NextUploadIDMarker
}
return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages)
}
// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated
// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads.
func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error {
uploads, err := c.ListMultipartUploads(ctx)
if err != nil {
return err
}
for _, u := range uploads {
if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) {
log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated)
c.abortMultipartUpload(ctx, u.Key, u.UploadID)
}
}
return nil
}
// putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
fullKey := c.objectKey(key)
@ -312,235 +207,80 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size
}
req.ContentLength = size
c.signV4(req, unsignedPayload)
resp, err := c.httpClient().Do(req)
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("s3: PutObject: %w", err)
}
defer resp.Body.Close()
resp.Body.Close()
if !isHTTPSuccess(resp) {
return parseError(resp)
}
return nil
}
// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
// chunks, uploading each as a separate part. This allows uploading without knowing the total
// body size in advance.
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
fullKey := c.objectKey(key)
// Step 1: Initiate multipart upload
uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
// do creates a request, signs it with an empty payload, executes it, reads the response body,
// and checks for errors. It is used for bodiless GET/POST requests.
func (c *Client) do(ctx context.Context, method, reqURL string, body io.Reader, op string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
if err != nil {
return err
return nil, fmt.Errorf("s3: %s request: %w", op, err)
}
// Step 2: Upload parts
var parts []completedPart
buf := make([]byte, partSize)
partNumber := 1
for {
n, err := io.ReadFull(body, buf)
if n > 0 {
etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
if uploadErr != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return uploadErr
}
parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
partNumber++
}
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
if err != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return fmt.Errorf("s3: PutObject read: %w", err)
}
if body == nil {
req.ContentLength = 0
}
// Step 3: Complete multipart upload
return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
}
// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
reqURL := c.objectURL(fullKey) + "?uploads"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
}
req.ContentLength = 0
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
resp, err := c.http.Do(req)
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
}
defer resp.Body.Close()
if !isHTTPSuccess(resp) {
return "", parseError(resp)
return nil, fmt.Errorf("s3: %s: %w", op, err)
}
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
resp.Body.Close()
if err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
return nil, fmt.Errorf("s3: %s read: %w", op, err)
}
var result initiateMultipartUploadResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
}
log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID)
return result.UploadID, nil
}
// uploadPart uploads a single part of a multipart upload and returns the ETag.
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data))
reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("s3: UploadPart request: %w", err)
}
req.ContentLength = int64(len(data))
c.signV4(req, unsignedPayload)
resp, err := c.httpClient().Do(req)
if err != nil {
return "", fmt.Errorf("s3: UploadPart: %w", err)
}
defer resp.Body.Close()
if !isHTTPSuccess(resp) {
return "", parseError(resp)
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
}
etag := resp.Header.Get("ETag")
return etag, nil
return respBody, nil
}
// completeMultipartUpload finalizes a multipart upload with the given parts.
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts))
var body bytes.Buffer
body.WriteString("<CompleteMultipartUpload>")
for _, p := range parts {
fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
}
body.WriteString("</CompleteMultipartUpload>")
bodyBytes := body.Bytes()
payloadHash := sha256Hex(bodyBytes)
// doWithBody is like do, but sends a body with a computed SHA-256 payload hash and Content-Type: application/xml.
func (c *Client) doWithBody(ctx context.Context, method, reqURL string, bodyBytes []byte, op string) ([]byte, error) {
return c.doWithBodyAndHeaders(ctx, method, reqURL, bodyBytes, nil, op)
}
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
// doWithBodyAndHeaders is like doWithBody, but allows setting additional headers (e.g. Content-MD5).
func (c *Client) doWithBodyAndHeaders(ctx context.Context, method, reqURL string, bodyBytes []byte, headers map[string]string, op string) ([]byte, error) {
payloadHash := sha256Hex(bodyBytes)
req, err := http.NewRequestWithContext(ctx, method, reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
return nil, fmt.Errorf("s3: %s request: %w", op, err)
}
req.ContentLength = int64(len(bodyBytes))
req.Header.Set("Content-Type", "application/xml")
for k, v := range headers {
req.Header.Set(k, v)
}
c.signV4(req, payloadHash)
resp, err := c.httpClient().Do(req)
resp, err := c.http.Do(req)
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
return nil, fmt.Errorf("s3: %s: %w", op, err)
}
defer resp.Body.Close()
if !isHTTPSuccess(resp) {
return parseError(resp)
}
// Read response body to check for errors (S3 can return 200 with an error body)
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
if err != nil {
return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
}
// Check if the response contains an error
var errResp ErrorResponse
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
errResp.StatusCode = resp.StatusCode
return &errResp
}
return nil
}
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID)
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
if err != nil {
return
}
c.signV4(req, emptyPayloadHash)
resp, err := c.httpClient().Do(req)
if err != nil {
return
}
resp.Body.Close()
}
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
func (c *Client) signV4(req *http.Request, payloadHash string) {
now := time.Now().UTC()
datestamp := now.Format("20060102")
amzDate := now.Format("20060102T150405Z")
// Required headers
req.Header.Set("Host", c.hostHeader())
req.Header.Set("X-Amz-Date", amzDate)
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
// Canonical headers (all headers we set, sorted by lowercase key)
signedKeys := make([]string, 0, len(req.Header))
canonHeaders := make(map[string]string, len(req.Header))
for k := range req.Header {
lk := strings.ToLower(k)
signedKeys = append(signedKeys, lk)
canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
if err != nil {
return nil, fmt.Errorf("s3: %s read: %w", op, err)
}
sort.Strings(signedKeys)
signedHeadersStr := strings.Join(signedKeys, ";")
var chBuf strings.Builder
for _, k := range signedKeys {
chBuf.WriteString(k)
chBuf.WriteByte(':')
chBuf.WriteString(canonHeaders[k])
chBuf.WriteByte('\n')
if !isHTTPSuccess(resp) {
return nil, parseErrorFromBytes(resp.StatusCode, respBody)
}
// Canonical request
canonicalRequest := strings.Join([]string{
req.Method,
canonicalURI(req.URL),
canonicalQueryString(req.URL.Query()),
chBuf.String(),
signedHeadersStr,
payloadHash,
}, "\n")
// String to sign
credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request"
stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))
// Signing key
signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
[]byte("AWS4"+c.SecretKey), []byte(datestamp)),
[]byte(c.Region)),
[]byte("s3")),
[]byte("aws4_request"))
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
req.Header.Set("Authorization", fmt.Sprintf(
"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
c.AccessKey, credentialScope, signedHeadersStr, signature,
))
}
func (c *Client) httpClient() *http.Client {
if c.HTTPClient != nil {
return c.HTTPClient
}
return http.DefaultClient
return respBody, nil
}
// objectKey prepends the configured prefix to the given key.
func (c *Client) objectKey(key string) string {
if c.Prefix != "" {
return c.Prefix + "/" + key
if c.config.Prefix != "" {
return c.config.Prefix + "/" + key
}
return key
}
@ -548,18 +288,19 @@ func (c *Client) objectKey(key string) string {
// prefixForList returns the prefix to use in ListObjectsV2 requests,
// with a trailing slash so that only objects under the prefix directory are returned.
func (c *Client) prefixForList() string {
if c.Prefix != "" {
return c.Prefix + "/"
if c.config.Prefix != "" {
return c.config.Prefix + "/"
}
return ""
}
// bucketURL returns the base URL for bucket-level operations.
func (c *Client) bucketURL() string {
if c.PathStyle {
return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
// stripPrefix removes the configured prefix from a key returned by ListObjectsV2,
// so keys match what was passed to PutObject/GetObject/DeleteObjects.
func (c *Client) stripPrefix(key string) string {
if c.config.Prefix != "" {
return strings.TrimPrefix(key, c.config.Prefix+"/")
}
return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
return key
}
// objectURL returns the full URL for an object (key should already include the prefix).
@ -569,13 +310,5 @@ func (c *Client) objectURL(key string) string {
for i, seg := range segments {
segments[i] = uriEncode(seg)
}
return c.bucketURL() + "/" + strings.Join(segments, "/")
}
// hostHeader returns the value for the Host header.
func (c *Client) hostHeader() string {
if c.PathStyle {
return c.Endpoint
}
return c.Bucket + "." + c.Endpoint
return c.config.BucketURL() + "/" + strings.Join(segments, "/")
}

68
s3/client_auth.go Normal file
View file

@ -0,0 +1,68 @@
package s3
import (
"encoding/hex"
"fmt"
"net/http"
"sort"
"strings"
"time"
)
// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
func (c *Client) signV4(req *http.Request, payloadHash string) {
now := time.Now().UTC()
datestamp := now.Format("20060102")
amzDate := now.Format("20060102T150405Z")
// Required headers
req.Header.Set("Host", c.config.HostHeader())
req.Header.Set("X-Amz-Date", amzDate)
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
// Canonical headers (all headers we set, sorted by lowercase key)
signedKeys := make([]string, 0, len(req.Header))
canonHeaders := make(map[string]string, len(req.Header))
for k := range req.Header {
lk := strings.ToLower(k)
signedKeys = append(signedKeys, lk)
canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
}
sort.Strings(signedKeys)
signedHeadersStr := strings.Join(signedKeys, ";")
var chBuf strings.Builder
for _, k := range signedKeys {
chBuf.WriteString(k)
chBuf.WriteByte(':')
chBuf.WriteString(canonHeaders[k])
chBuf.WriteByte('\n')
}
// Canonical request
canonicalRequest := strings.Join([]string{
req.Method,
canonicalURI(req.URL),
canonicalQueryString(req.URL.Query()),
chBuf.String(),
signedHeadersStr,
payloadHash,
}, "\n")
// String to sign
credentialScope := datestamp + "/" + c.config.Region + "/s3/aws4_request"
stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))
// Signing key
signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
[]byte("AWS4"+c.config.SecretKey), []byte(datestamp)),
[]byte(c.config.Region)),
[]byte("s3")),
[]byte("aws4_request"))
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
req.Header.Set("Authorization", fmt.Sprintf(
"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
c.config.AccessKey, credentialScope, signedHeadersStr, signature,
))
}

188
s3/client_multipart.go Normal file
View file

@ -0,0 +1,188 @@
package s3
import (
"bytes"
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"heckel.io/ntfy/v2/log"
)
// ListMultipartUploads returns in-progress multipart uploads for the client's prefix.
// It paginates automatically, stopping after 10,000 pages as a safety valve.
func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) {
var all []MultipartUpload
var keyMarker, uploadIDMarker string
for page := 0; page < maxPages; page++ {
query := url.Values{"uploads": {""}}
if prefix := c.prefixForList(); prefix != "" {
query.Set("prefix", prefix)
}
if keyMarker != "" {
query.Set("key-marker", keyMarker)
query.Set("upload-id-marker", uploadIDMarker)
}
respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListMultipartUploads")
if err != nil {
return nil, err
}
var result listMultipartUploadsResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err)
}
for _, u := range result.Uploads {
var initiated time.Time
if u.Initiated != "" {
initiated, _ = time.Parse(time.RFC3339, u.Initiated)
}
all = append(all, MultipartUpload{
Key: u.Key,
UploadID: u.UploadID,
Initiated: initiated,
})
}
if !result.IsTruncated {
return all, nil
}
keyMarker = result.NextKeyMarker
uploadIDMarker = result.NextUploadIDMarker
}
return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages)
}
// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated
// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads.
func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error {
uploads, err := c.ListMultipartUploads(ctx)
if err != nil {
return err
}
for _, u := range uploads {
if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) {
log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated)
c.abortMultipartUpload(ctx, u.Key, u.UploadID)
}
}
return nil
}
// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
// chunks, uploading each as a separate part. This allows uploading without knowing the total
// body size in advance.
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
fullKey := c.objectKey(key)
// Step 1: Initiate multipart upload
uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
if err != nil {
return err
}
// Step 2: Upload parts
var parts []completedPart
buf := make([]byte, partSize)
partNumber := 1
for {
n, err := io.ReadFull(body, buf)
if n > 0 {
etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
if uploadErr != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return uploadErr
}
parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
partNumber++
}
if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
break
}
if err != nil {
c.abortMultipartUpload(ctx, fullKey, uploadID)
return fmt.Errorf("s3: PutObject read: %w", err)
}
}
// Step 3: Complete multipart upload
return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
}
// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
respBody, err := c.do(ctx, http.MethodPost, c.objectURL(fullKey)+"?uploads", nil, "InitiateMultipartUpload")
if err != nil {
return "", err
}
var result initiateMultipartUploadResult
if err := xml.Unmarshal(respBody, &result); err != nil {
return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
}
log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID)
return result.UploadID, nil
}
// uploadPart uploads a single part of a multipart upload and returns the ETag.
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data))
reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("s3: UploadPart request: %w", err)
}
req.ContentLength = int64(len(data))
c.signV4(req, unsignedPayload)
resp, err := c.http.Do(req)
if err != nil {
return "", fmt.Errorf("s3: UploadPart: %w", err)
}
defer resp.Body.Close()
if !isHTTPSuccess(resp) {
return "", parseError(resp)
}
etag := resp.Header.Get("ETag")
return etag, nil
}
// completeMultipartUpload finalizes a multipart upload with the given parts.
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts))
var body bytes.Buffer
body.WriteString("<CompleteMultipartUpload>")
for _, p := range parts {
fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
}
body.WriteString("</CompleteMultipartUpload>")
respBody, err := c.doWithBody(ctx, http.MethodPost,
fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)),
body.Bytes(), "CompleteMultipartUpload")
if err != nil {
return err
}
// Check if the response contains an error (S3 can return 200 with an error body)
var errResp ErrorResponse
if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
return &errResp
}
return nil
}
// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID)
reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
if err != nil {
return
}
c.signV4(req, emptyPayloadHash)
resp, err := c.http.Do(req)
if err != nil {
return
}
resp.Body.Close()
}

View file

@ -269,7 +269,7 @@ func (m *mockS3Server) objectCount() int {
func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
host := strings.TrimPrefix(server.URL, "https://")
return &Client{
return New(&Config{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
@ -278,7 +278,7 @@ func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
Prefix: prefix,
PathStyle: true,
HTTPClient: server.Client(),
}
})
}
// --- URL parsing tests ---
@ -363,49 +363,49 @@ func TestParseURL_EmptyBucket(t *testing.T) {
// --- Unit tests: URL construction ---
func TestClient_BucketURL_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL())
func TestConfig_BucketURL_PathStyle(t *testing.T) {
c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "https://s3.example.com/my-bucket", c.BucketURL())
}
func TestClient_BucketURL_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL())
func TestConfig_BucketURL_VirtualHosted(t *testing.T) {
c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.BucketURL())
}
func TestClient_ObjectURL_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
c := &Client{config: &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}}
require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj"))
}
func TestClient_ObjectURL_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
c := &Client{config: &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}}
require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj"))
}
func TestClient_HostHeader_PathStyle(t *testing.T) {
c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "s3.example.com", c.hostHeader())
func TestConfig_HostHeader_PathStyle(t *testing.T) {
c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
require.Equal(t, "s3.example.com", c.HostHeader())
}
func TestClient_HostHeader_VirtualHosted(t *testing.T) {
c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader())
func TestConfig_HostHeader_VirtualHosted(t *testing.T) {
c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.HostHeader())
}
func TestClient_ObjectKey(t *testing.T) {
c := &Client{Prefix: "attachments"}
c := &Client{config: &Config{Prefix: "attachments"}}
require.Equal(t, "attachments/file123", c.objectKey("file123"))
c2 := &Client{Prefix: ""}
c2 := &Client{config: &Config{Prefix: ""}}
require.Equal(t, "file123", c2.objectKey("file123"))
}
func TestClient_PrefixForList(t *testing.T) {
c := &Client{Prefix: "attachments"}
c := &Client{config: &Config{Prefix: "attachments"}}
require.Equal(t, "attachments/", c.prefixForList())
c2 := &Client{Prefix: ""}
c2 := &Client{config: &Config{Prefix: ""}}
require.Equal(t, "", c2.prefixForList())
}
@ -512,13 +512,13 @@ func TestClient_ListObjects(t *testing.T) {
require.Nil(t, err)
// List with prefix client: should only see 3
result, err := client.ListObjects(ctx, "", 0)
result, err := client.listObjects(ctx, "", 0)
require.Nil(t, err)
require.Len(t, result.Objects, 3)
require.False(t, result.IsTruncated)
// List with no-prefix client: should see all 4
result, err = clientNoPrefix.ListObjects(ctx, "", 0)
result, err = clientNoPrefix.listObjects(ctx, "", 0)
require.Nil(t, err)
require.Len(t, result.Objects, 4)
}
@ -537,20 +537,20 @@ func TestClient_ListObjects_Pagination(t *testing.T) {
}
// List with max-keys=2
result, err := client.ListObjects(ctx, "", 2)
result, err := client.listObjects(ctx, "", 2)
require.Nil(t, err)
require.Len(t, result.Objects, 2)
require.True(t, result.IsTruncated)
require.NotEmpty(t, result.NextContinuationToken)
// Get next page
result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2)
result2, err := client.listObjects(ctx, result.NextContinuationToken, 2)
require.Nil(t, err)
require.Len(t, result2.Objects, 2)
require.True(t, result2.IsTruncated)
// Get last page
result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2)
result3, err := client.listObjects(ctx, result2.NextContinuationToken, 2)
require.Nil(t, err)
require.Len(t, result3.Objects, 1)
require.False(t, result3.IsTruncated)
@ -744,7 +744,7 @@ func TestClient_RealBucket(t *testing.T) {
prefix = "ntfy-s3-test"
}
client := &Client{
client := New(&Config{
AccessKey: accessKey,
SecretKey: secretKey,
Region: region,
@ -752,7 +752,7 @@ func TestClient_RealBucket(t *testing.T) {
Bucket: bucket,
Prefix: prefix,
PathStyle: pathStyle,
}
})
ctx := context.Background()
@ -762,8 +762,7 @@ func TestClient_RealBucket(t *testing.T) {
if len(existing) > 0 {
keys := make([]string, len(existing))
for i, obj := range existing {
// Strip the prefix since DeleteObjects will re-add it
keys[i] = strings.TrimPrefix(obj.Key, prefix+"/")
keys[i] = obj.Key
}
// Batch delete in groups of 1000
for i := 0; i < len(keys); i += 1000 {
@ -807,7 +806,7 @@ func TestClient_RealBucket(t *testing.T) {
t.Run("ListObjects", func(t *testing.T) {
// Use a sub-prefix client for isolation
listClient := &Client{
listClient := New(&Config{
AccessKey: accessKey,
SecretKey: secretKey,
Region: region,
@ -815,7 +814,7 @@ func TestClient_RealBucket(t *testing.T) {
Bucket: bucket,
Prefix: prefix + "/list-test",
PathStyle: pathStyle,
}
})
// Put 10 objects
for i := 0; i < 10; i++ {

View file

@ -2,18 +2,36 @@ package s3
import (
"fmt"
"net/http"
"time"
)
// Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
type Config struct {
Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
PathStyle bool
Bucket string
Prefix string
Region string
AccessKey string
SecretKey string
Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
PathStyle bool
Bucket string
Prefix string
Region string
AccessKey string
SecretKey string
HTTPClient *http.Client // if nil, http.DefaultClient is used
}
// bucketURL returns the base URL for bucket-level operations.
func (c *Config) BucketURL() string {
if c.PathStyle {
return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
}
return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
}
// hostHeader returns the value for the Host header.
func (c *Config) HostHeader() string {
if c.PathStyle {
return c.Endpoint
}
return c.Bucket + "." + c.Endpoint
}
// Object represents an S3 object returned by list operations.
@ -23,8 +41,8 @@ type Object struct {
LastModified time.Time
}
// ListResult holds the response from a ListObjectsV2 call.
type ListResult struct {
// listResult holds the response from a single ListObjectsV2 page.
type listResult struct {
Objects []Object
IsTruncated bool
NextContinuationToken string
@ -78,10 +96,10 @@ type MultipartUpload struct {
// listMultipartUploadsResult is the XML response from S3 ListMultipartUploads
type listMultipartUploadsResult struct {
Uploads []listUpload `xml:"Upload"`
IsTruncated bool `xml:"IsTruncated"`
NextKeyMarker string `xml:"NextKeyMarker"`
NextUploadIDMarker string `xml:"NextUploadIdMarker"`
Uploads []listUpload `xml:"Upload"`
IsTruncated bool `xml:"IsTruncated"`
NextKeyMarker string `xml:"NextKeyMarker"`
NextUploadIDMarker string `xml:"NextUploadIdMarker"`
}
type listUpload struct {

View file

@ -105,13 +105,13 @@ func TestHmacSHA256(t *testing.T) {
}
func TestSignV4_SetsRequiredHeaders(t *testing.T) {
c := &Client{
c := &Client{config: &Config{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: "s3.us-east-1.amazonaws.com",
Bucket: "my-bucket",
}
}}
req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
c.signV4(req, emptyPayloadHash)
@ -131,13 +131,13 @@ func TestSignV4_SetsRequiredHeaders(t *testing.T) {
}
func TestSignV4_UnsignedPayload(t *testing.T) {
c := &Client{
c := &Client{config: &Config{
AccessKey: "AKID",
SecretKey: "SECRET",
Region: "us-east-1",
Endpoint: "s3.us-east-1.amazonaws.com",
Bucket: "my-bucket",
}
}}
req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
c.signV4(req, unsignedPayload)
@ -146,8 +146,8 @@ func TestSignV4_UnsignedPayload(t *testing.T) {
}
func TestSignV4_DifferentRegions(t *testing.T) {
c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}
c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}
c1 := &Client{config: &Config{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}}
c2 := &Client{config: &Config{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}}
req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil)
c1.signV4(req1, emptyPayloadHash)