diff --git a/attachment/backend_s3.go b/attachment/backend_s3.go index 11d2254b..6603bb91 100644 --- a/attachment/backend_s3.go +++ b/attachment/backend_s3.go @@ -3,7 +3,6 @@ package attachment import ( "context" "io" - "strings" "time" "heckel.io/ntfy/v2/log" @@ -38,15 +37,10 @@ func (b *s3Backend) List() ([]object, error) { if err != nil { return nil, err } - prefix := b.client.Prefix result := make([]object, 0, len(objects)) for _, obj := range objects { - id := obj.Key - if prefix != "" { - id = strings.TrimPrefix(id, prefix+"/") - } result = append(result, object{ - ID: id, + ID: obj.Key, Size: obj.Size, LastModified: obj.LastModified, }) diff --git a/attachment/store_s3_test.go b/attachment/store_s3_test.go index 7ae122ae..3ad5a93c 100644 --- a/attachment/store_s3_test.go +++ b/attachment/store_s3_test.go @@ -197,7 +197,7 @@ func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) { func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store { t.Helper() host := strings.TrimPrefix(server.URL, "https://") - backend := newS3Backend(&s3.Client{ + backend := newS3Backend(s3.New(&s3.Config{ AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", @@ -206,7 +206,7 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string Prefix: prefix, PathStyle: true, HTTPClient: server.Client(), - }) + })) cache, err := newStore(backend, totalSizeLimit, nil) require.Nil(t, err) t.Cleanup(func() { cache.Close() }) diff --git a/s3/client.go b/s3/client.go index 41d940f3..5c43af05 100644 --- a/s3/client.go +++ b/s3/client.go @@ -8,14 +8,12 @@ import ( "context" "crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers "encoding/base64" - "encoding/hex" "encoding/xml" "errors" "fmt" "io" "net/http" "net/url" - "sort" "strconv" "strings" "time" @@ -33,26 +31,19 @@ const ( // // Fields must not be modified after the Client is passed to any method or goroutine. 
type Client struct { - AccessKey string // AWS access key ID - SecretKey string // AWS secret access key - Region string // e.g. "us-east-1" - Endpoint string // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com" - Bucket string // S3 bucket name - Prefix string // optional key prefix (e.g. "attachments"); prepended to all keys automatically - PathStyle bool // if true, use path-style addressing; otherwise virtual-hosted-style - HTTPClient *http.Client // if nil, http.DefaultClient is used + config *Config + http *http.Client } // New creates a new S3 client from the given Config. func New(config *Config) *Client { + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } return &Client{ - AccessKey: config.AccessKey, - SecretKey: config.SecretKey, - Region: config.Region, - Endpoint: config.Endpoint, - Bucket: config.Bucket, - Prefix: config.Prefix, - PathStyle: config.PathStyle, + config: config, + http: httpClient, } } @@ -87,7 +78,7 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int6 return nil, 0, fmt.Errorf("s3: GetObject request: %w", err) } c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) + resp, err := c.http.Do(req) if err != nil { return nil, 0, fmt.Errorf("s3: GetObject: %w", err) } @@ -116,35 +107,18 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { } body.WriteString("") bodyBytes := body.Bytes() - payloadHash := sha256Hex(bodyBytes) // Content-MD5 is required by the S3 protocol for DeleteObjects requests. 
md5Sum := md5.Sum(bodyBytes) //nolint:gosec contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:]) - reqURL := c.bucketURL() + "?delete=" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes)) + respBody, err := c.doWithBodyAndHeaders(ctx, http.MethodPost, c.config.BucketURL()+"?delete=", bodyBytes, + map[string]string{"Content-MD5": contentMD5}, "DeleteObjects") if err != nil { - return fmt.Errorf("s3: DeleteObjects request: %w", err) - } - req.ContentLength = int64(len(bodyBytes)) - req.Header.Set("Content-Type", "application/xml") - req.Header.Set("Content-MD5", contentMD5) - c.signV4(req, payloadHash) - resp, err := c.httpClient().Do(req) - if err != nil { - return fmt.Errorf("s3: DeleteObjects: %w", err) - } - defer resp.Body.Close() - if !isHTTPSuccess(resp) { - return parseError(resp) + return err } // S3 may return HTTP 200 with per-key errors in the response body - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - if err != nil { - return fmt.Errorf("s3: DeleteObjects read response: %w", err) - } var result deleteResult if err := xml.Unmarshal(respBody, &result); err != nil { return nil // If we can't parse, assume success (Quiet mode returns empty body on success) @@ -159,9 +133,9 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error { return nil } -// ListObjects performs a single ListObjectsV2 request using the client's configured prefix. +// listObjects performs a single ListObjectsV2 request using the client's configured prefix. // Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000). 
-func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) { +func (c *Client) listObjects(ctx context.Context, continuationToken string, maxKeys int) (*listResult, error) { log.Tag(tagS3Client).Debug("ListObjects continuation=%s maxKeys=%d", continuationToken, maxKeys) query := url.Values{"list-type": {"2"}} if prefix := c.prefixForList(); prefix != "" { @@ -173,22 +147,9 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK if maxKeys > 0 { query.Set("max-keys", strconv.Itoa(maxKeys)) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil) + respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListObjects") if err != nil { - return nil, fmt.Errorf("s3: ListObjects request: %w", err) - } - c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) - if err != nil { - return nil, fmt.Errorf("s3: ListObjects: %w", err) - } - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("s3: ListObjects read: %w", err) - } - if !isHTTPSuccess(resp) { - return nil, parseErrorFromBytes(resp.StatusCode, respBody) + return nil, err } var result listObjectsV2Response if err := xml.Unmarshal(respBody, &result); err != nil { @@ -206,7 +167,7 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK LastModified: lastModified, } } - return &ListResult{ + return &listResult{ Objects: objects, IsTruncated: result.IsTruncated, NextContinuationToken: result.NextContinuationToken, @@ -214,16 +175,21 @@ func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxK } // ListAllObjects returns all objects under the client's configured prefix by paginating through -// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve. +// ListObjectsV2 results automatically. 
Keys in the returned objects have the prefix stripped, +// so they match the keys used with PutObject/GetObject/DeleteObjects. It stops after 10,000 +// pages as a safety valve. func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { var all []Object var token string for page := 0; page < maxPages; page++ { - result, err := c.ListObjects(ctx, token, 0) + result, err := c.listObjects(ctx, token, 0) if err != nil { return nil, err } - all = append(all, result.Objects...) + for _, obj := range result.Objects { + obj.Key = c.stripPrefix(obj.Key) + all = append(all, obj) + } if !result.IsTruncated { return all, nil } @@ -232,77 +198,6 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) { return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages) } -// ListMultipartUploads returns in-progress multipart uploads for the client's prefix. -// It paginates automatically, stopping after 10,000 pages as a safety valve. -func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) { - var all []MultipartUpload - var keyMarker, uploadIDMarker string - for page := 0; page < maxPages; page++ { - query := url.Values{"uploads": {""}} - if prefix := c.prefixForList(); prefix != "" { - query.Set("prefix", prefix) - } - if keyMarker != "" { - query.Set("key-marker", keyMarker) - query.Set("upload-id-marker", uploadIDMarker) - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil) - if err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads request: %w", err) - } - c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) - if err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads: %w", err) - } - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - resp.Body.Close() - if err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads read: %w", err) - } - if !isHTTPSuccess(resp) { - return nil, 
parseErrorFromBytes(resp.StatusCode, respBody) - } - var result listMultipartUploadsResult - if err := xml.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err) - } - for _, u := range result.Uploads { - var initiated time.Time - if u.Initiated != "" { - initiated, _ = time.Parse(time.RFC3339, u.Initiated) - } - all = append(all, MultipartUpload{ - Key: u.Key, - UploadID: u.UploadID, - Initiated: initiated, - }) - } - if !result.IsTruncated { - return all, nil - } - keyMarker = result.NextKeyMarker - uploadIDMarker = result.NextUploadIDMarker - } - return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages) -} - -// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated -// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads. -func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error { - uploads, err := c.ListMultipartUploads(ctx) - if err != nil { - return err - } - for _, u := range uploads { - if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) { - log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated) - c.abortMultipartUpload(ctx, u.Key, u.UploadID) - } - } - return nil -} - // putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD. 
func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error { fullKey := c.objectKey(key) @@ -312,235 +207,80 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size } req.ContentLength = size c.signV4(req, unsignedPayload) - resp, err := c.httpClient().Do(req) + resp, err := c.http.Do(req) if err != nil { return fmt.Errorf("s3: PutObject: %w", err) } - defer resp.Body.Close() + resp.Body.Close() if !isHTTPSuccess(resp) { return parseError(resp) } return nil } -// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize -// chunks, uploading each as a separate part. This allows uploading without knowing the total -// body size in advance. -func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error { - fullKey := c.objectKey(key) - - // Step 1: Initiate multipart upload - uploadID, err := c.initiateMultipartUpload(ctx, fullKey) +// do creates a request, signs it with an empty payload, executes it, reads the response body, +// and checks for errors. It is used for bodiless GET/POST requests. 
+func (c *Client) do(ctx context.Context, method, reqURL string, body io.Reader, op string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, method, reqURL, body) if err != nil { - return err + return nil, fmt.Errorf("s3: %s request: %w", op, err) } - - // Step 2: Upload parts - var parts []completedPart - buf := make([]byte, partSize) - partNumber := 1 - for { - n, err := io.ReadFull(body, buf) - if n > 0 { - etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n]) - if uploadErr != nil { - c.abortMultipartUpload(ctx, fullKey, uploadID) - return uploadErr - } - parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag}) - partNumber++ - } - if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { - break - } - if err != nil { - c.abortMultipartUpload(ctx, fullKey, uploadID) - return fmt.Errorf("s3: PutObject read: %w", err) - } + if body == nil { + req.ContentLength = 0 } - - // Step 3: Complete multipart upload - return c.completeMultipartUpload(ctx, fullKey, uploadID, parts) -} - -// initiateMultipartUpload starts a new multipart upload and returns the upload ID. 
-func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) { - reqURL := c.objectURL(fullKey) + "?uploads" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil) - if err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err) - } - req.ContentLength = 0 c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) + resp, err := c.http.Do(req) if err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err) - } - defer resp.Body.Close() - if !isHTTPSuccess(resp) { - return "", parseError(resp) + return nil, fmt.Errorf("s3: %s: %w", op, err) } respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + resp.Body.Close() if err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err) + return nil, fmt.Errorf("s3: %s read: %w", op, err) } - var result initiateMultipartUploadResult - if err := xml.Unmarshal(respBody, &result); err != nil { - return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) - } - log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID) - return result.UploadID, nil -} - -// uploadPart uploads a single part of a multipart upload and returns the ETag. 
-func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) { - log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data)) - reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID)) - req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) - if err != nil { - return "", fmt.Errorf("s3: UploadPart request: %w", err) - } - req.ContentLength = int64(len(data)) - c.signV4(req, unsignedPayload) - resp, err := c.httpClient().Do(req) - if err != nil { - return "", fmt.Errorf("s3: UploadPart: %w", err) - } - defer resp.Body.Close() if !isHTTPSuccess(resp) { - return "", parseError(resp) + return nil, parseErrorFromBytes(resp.StatusCode, respBody) } - etag := resp.Header.Get("ETag") - return etag, nil + return respBody, nil } -// completeMultipartUpload finalizes a multipart upload with the given parts. -func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error { - log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts)) - var body bytes.Buffer - body.WriteString("") - for _, p := range parts { - fmt.Fprintf(&body, "%d%s", p.PartNumber, p.ETag) - } - body.WriteString("") - bodyBytes := body.Bytes() - payloadHash := sha256Hex(bodyBytes) +// doWithBody is like do, but sends a body with a computed SHA-256 payload hash and Content-Type: application/xml. 
+func (c *Client) doWithBody(ctx context.Context, method, reqURL string, bodyBytes []byte, op string) ([]byte, error) { + return c.doWithBodyAndHeaders(ctx, method, reqURL, bodyBytes, nil, op) +} - reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes)) +// doWithBodyAndHeaders is like doWithBody, but allows setting additional headers (e.g. Content-MD5). +func (c *Client) doWithBodyAndHeaders(ctx context.Context, method, reqURL string, bodyBytes []byte, headers map[string]string, op string) ([]byte, error) { + payloadHash := sha256Hex(bodyBytes) + req, err := http.NewRequestWithContext(ctx, method, reqURL, bytes.NewReader(bodyBytes)) if err != nil { - return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err) + return nil, fmt.Errorf("s3: %s request: %w", op, err) } req.ContentLength = int64(len(bodyBytes)) req.Header.Set("Content-Type", "application/xml") + for k, v := range headers { + req.Header.Set(k, v) + } c.signV4(req, payloadHash) - resp, err := c.httpClient().Do(req) + resp, err := c.http.Do(req) if err != nil { - return fmt.Errorf("s3: CompleteMultipartUpload: %w", err) + return nil, fmt.Errorf("s3: %s: %w", op, err) } - defer resp.Body.Close() - if !isHTTPSuccess(resp) { - return parseError(resp) - } - // Read response body to check for errors (S3 can return 200 with an error body) respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - if err != nil { - return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err) - } - // Check if the response contains an error - var errResp ErrorResponse - if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" { - errResp.StatusCode = resp.StatusCode - return &errResp - } - return nil -} - -// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. 
-func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) { - log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID) - reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) - req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) - if err != nil { - return - } - c.signV4(req, emptyPayloadHash) - resp, err := c.httpClient().Do(req) - if err != nil { - return - } resp.Body.Close() -} - -// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256 -// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. -func (c *Client) signV4(req *http.Request, payloadHash string) { - now := time.Now().UTC() - datestamp := now.Format("20060102") - amzDate := now.Format("20060102T150405Z") - - // Required headers - req.Header.Set("Host", c.hostHeader()) - req.Header.Set("X-Amz-Date", amzDate) - req.Header.Set("X-Amz-Content-Sha256", payloadHash) - - // Canonical headers (all headers we set, sorted by lowercase key) - signedKeys := make([]string, 0, len(req.Header)) - canonHeaders := make(map[string]string, len(req.Header)) - for k := range req.Header { - lk := strings.ToLower(k) - signedKeys = append(signedKeys, lk) - canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k)) + if err != nil { + return nil, fmt.Errorf("s3: %s read: %w", op, err) } - sort.Strings(signedKeys) - signedHeadersStr := strings.Join(signedKeys, ";") - var chBuf strings.Builder - for _, k := range signedKeys { - chBuf.WriteString(k) - chBuf.WriteByte(':') - chBuf.WriteString(canonHeaders[k]) - chBuf.WriteByte('\n') + if !isHTTPSuccess(resp) { + return nil, parseErrorFromBytes(resp.StatusCode, respBody) } - - // Canonical request - canonicalRequest := strings.Join([]string{ - req.Method, - canonicalURI(req.URL), - canonicalQueryString(req.URL.Query()), - chBuf.String(), - signedHeadersStr, - payloadHash, - }, "\n") - - // String to 
sign - credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request" - stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest)) - - // Signing key - signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256( - []byte("AWS4"+c.SecretKey), []byte(datestamp)), - []byte(c.Region)), - []byte("s3")), - []byte("aws4_request")) - - signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) - req.Header.Set("Authorization", fmt.Sprintf( - "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", - c.AccessKey, credentialScope, signedHeadersStr, signature, - )) -} - -func (c *Client) httpClient() *http.Client { - if c.HTTPClient != nil { - return c.HTTPClient - } - return http.DefaultClient + return respBody, nil } // objectKey prepends the configured prefix to the given key. func (c *Client) objectKey(key string) string { - if c.Prefix != "" { - return c.Prefix + "/" + key + if c.config.Prefix != "" { + return c.config.Prefix + "/" + key } return key } @@ -548,18 +288,19 @@ func (c *Client) objectKey(key string) string { // prefixForList returns the prefix to use in ListObjectsV2 requests, // with a trailing slash so that only objects under the prefix directory are returned. func (c *Client) prefixForList() string { - if c.Prefix != "" { - return c.Prefix + "/" + if c.config.Prefix != "" { + return c.config.Prefix + "/" } return "" } -// bucketURL returns the base URL for bucket-level operations. -func (c *Client) bucketURL() string { - if c.PathStyle { - return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket) +// stripPrefix removes the configured prefix from a key returned by ListObjectsV2, +// so keys match what was passed to PutObject/GetObject/DeleteObjects. 
+func (c *Client) stripPrefix(key string) string { + if c.config.Prefix != "" { + return strings.TrimPrefix(key, c.config.Prefix+"/") } - return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint) + return key } // objectURL returns the full URL for an object (key should already include the prefix). @@ -569,13 +310,5 @@ func (c *Client) objectURL(key string) string { for i, seg := range segments { segments[i] = uriEncode(seg) } - return c.bucketURL() + "/" + strings.Join(segments, "/") -} - -// hostHeader returns the value for the Host header. -func (c *Client) hostHeader() string { - if c.PathStyle { - return c.Endpoint - } - return c.Bucket + "." + c.Endpoint + return c.config.BucketURL() + "/" + strings.Join(segments, "/") } diff --git a/s3/client_auth.go b/s3/client_auth.go new file mode 100644 index 00000000..ede971f3 --- /dev/null +++ b/s3/client_auth.go @@ -0,0 +1,68 @@ +package s3 + +import ( + "encoding/hex" + "fmt" + "net/http" + "sort" + "strings" + "time" +) + +// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256 +// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads. 
+func (c *Client) signV4(req *http.Request, payloadHash string) { + now := time.Now().UTC() + datestamp := now.Format("20060102") + amzDate := now.Format("20060102T150405Z") + + // Required headers + req.Header.Set("Host", c.config.HostHeader()) + req.Header.Set("X-Amz-Date", amzDate) + req.Header.Set("X-Amz-Content-Sha256", payloadHash) + + // Canonical headers (all headers we set, sorted by lowercase key) + signedKeys := make([]string, 0, len(req.Header)) + canonHeaders := make(map[string]string, len(req.Header)) + for k := range req.Header { + lk := strings.ToLower(k) + signedKeys = append(signedKeys, lk) + canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k)) + } + sort.Strings(signedKeys) + signedHeadersStr := strings.Join(signedKeys, ";") + var chBuf strings.Builder + for _, k := range signedKeys { + chBuf.WriteString(k) + chBuf.WriteByte(':') + chBuf.WriteString(canonHeaders[k]) + chBuf.WriteByte('\n') + } + + // Canonical request + canonicalRequest := strings.Join([]string{ + req.Method, + canonicalURI(req.URL), + canonicalQueryString(req.URL.Query()), + chBuf.String(), + signedHeadersStr, + payloadHash, + }, "\n") + + // String to sign + credentialScope := datestamp + "/" + c.config.Region + "/s3/aws4_request" + stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest)) + + // Signing key + signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256( + []byte("AWS4"+c.config.SecretKey), []byte(datestamp)), + []byte(c.config.Region)), + []byte("s3")), + []byte("aws4_request")) + + signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) + req.Header.Set("Authorization", fmt.Sprintf( + "AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", + c.config.AccessKey, credentialScope, signedHeadersStr, signature, + )) +} diff --git a/s3/client_multipart.go b/s3/client_multipart.go new file mode 100644 index 00000000..61d206b6 --- /dev/null +++ b/s3/client_multipart.go @@ 
-0,0 +1,188 @@ +package s3 + +import ( + "bytes" + "context" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "heckel.io/ntfy/v2/log" +) + +// ListMultipartUploads returns in-progress multipart uploads for the client's prefix. +// It paginates automatically, stopping after 10,000 pages as a safety valve. +func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) { + var all []MultipartUpload + var keyMarker, uploadIDMarker string + for page := 0; page < maxPages; page++ { + query := url.Values{"uploads": {""}} + if prefix := c.prefixForList(); prefix != "" { + query.Set("prefix", prefix) + } + if keyMarker != "" { + query.Set("key-marker", keyMarker) + query.Set("upload-id-marker", uploadIDMarker) + } + respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListMultipartUploads") + if err != nil { + return nil, err + } + var result listMultipartUploadsResult + if err := xml.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err) + } + for _, u := range result.Uploads { + var initiated time.Time + if u.Initiated != "" { + initiated, _ = time.Parse(time.RFC3339, u.Initiated) + } + all = append(all, MultipartUpload{ + Key: u.Key, + UploadID: u.UploadID, + Initiated: initiated, + }) + } + if !result.IsTruncated { + return all, nil + } + keyMarker = result.NextKeyMarker + uploadIDMarker = result.NextUploadIDMarker + } + return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages) +} + +// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated +// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads. 
+func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error { + uploads, err := c.ListMultipartUploads(ctx) + if err != nil { + return err + } + for _, u := range uploads { + if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) { + log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated) + c.abortMultipartUpload(ctx, u.Key, u.UploadID) + } + } + return nil +} + +// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize +// chunks, uploading each as a separate part. This allows uploading without knowing the total +// body size in advance. +func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error { + fullKey := c.objectKey(key) + + // Step 1: Initiate multipart upload + uploadID, err := c.initiateMultipartUpload(ctx, fullKey) + if err != nil { + return err + } + + // Step 2: Upload parts + var parts []completedPart + buf := make([]byte, partSize) + partNumber := 1 + for { + n, err := io.ReadFull(body, buf) + if n > 0 { + etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n]) + if uploadErr != nil { + c.abortMultipartUpload(ctx, fullKey, uploadID) + return uploadErr + } + parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag}) + partNumber++ + } + if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { + break + } + if err != nil { + c.abortMultipartUpload(ctx, fullKey, uploadID) + return fmt.Errorf("s3: PutObject read: %w", err) + } + } + + // Step 3: Complete multipart upload + return c.completeMultipartUpload(ctx, fullKey, uploadID, parts) +} + +// initiateMultipartUpload starts a new multipart upload and returns the upload ID. 
+func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) { + respBody, err := c.do(ctx, http.MethodPost, c.objectURL(fullKey)+"?uploads", nil, "InitiateMultipartUpload") + if err != nil { + return "", err + } + var result initiateMultipartUploadResult + if err := xml.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err) + } + log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID) + return result.UploadID, nil +} + +// uploadPart uploads a single part of a multipart upload and returns the ETag. +func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) { + log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data)) + reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data)) + if err != nil { + return "", fmt.Errorf("s3: UploadPart request: %w", err) + } + req.ContentLength = int64(len(data)) + c.signV4(req, unsignedPayload) + resp, err := c.http.Do(req) + if err != nil { + return "", fmt.Errorf("s3: UploadPart: %w", err) + } + defer resp.Body.Close() + if !isHTTPSuccess(resp) { + return "", parseError(resp) + } + etag := resp.Header.Get("ETag") + return etag, nil +} + +// completeMultipartUpload finalizes a multipart upload with the given parts. 
+func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error { + log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts)) + var body bytes.Buffer + body.WriteString("") + for _, p := range parts { + fmt.Fprintf(&body, "%d%s", p.PartNumber, p.ETag) + } + body.WriteString("") + respBody, err := c.doWithBody(ctx, http.MethodPost, + fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)), + body.Bytes(), "CompleteMultipartUpload") + if err != nil { + return err + } + // Check if the response contains an error (S3 can return 200 with an error body) + var errResp ErrorResponse + if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" { + return &errResp + } + return nil +} + +// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up. +func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) { + log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID) + reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil) + if err != nil { + return + } + c.signV4(req, emptyPayloadHash) + resp, err := c.http.Do(req) + if err != nil { + return + } + resp.Body.Close() +} diff --git a/s3/client_test.go b/s3/client_test.go index 8007601c..447f36d0 100644 --- a/s3/client_test.go +++ b/s3/client_test.go @@ -269,7 +269,7 @@ func (m *mockS3Server) objectCount() int { func newTestClient(server *httptest.Server, bucket, prefix string) *Client { // httptest.NewTLSServer URL is like "https://127.0.0.1:PORT" host := strings.TrimPrefix(server.URL, "https://") - return &Client{ + return New(&Config{ AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", @@ -278,7 +278,7 @@ func newTestClient(server *httptest.Server, bucket, prefix string) *Client { 
Prefix: prefix, PathStyle: true, HTTPClient: server.Client(), - } + }) } // --- URL parsing tests --- @@ -363,49 +363,49 @@ func TestParseURL_EmptyBucket(t *testing.T) { // --- Unit tests: URL construction --- -func TestClient_BucketURL_PathStyle(t *testing.T) { - c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} - require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL()) +func TestConfig_BucketURL_PathStyle(t *testing.T) { + c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} + require.Equal(t, "https://s3.example.com/my-bucket", c.BucketURL()) } -func TestClient_BucketURL_VirtualHosted(t *testing.T) { - c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} - require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL()) +func TestConfig_BucketURL_VirtualHosted(t *testing.T) { + c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.BucketURL()) } func TestClient_ObjectURL_PathStyle(t *testing.T) { - c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} + c := &Client{config: &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}} require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj")) } func TestClient_ObjectURL_VirtualHosted(t *testing.T) { - c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + c := &Client{config: &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}} require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj")) } -func TestClient_HostHeader_PathStyle(t *testing.T) { - c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} - require.Equal(t, "s3.example.com", c.hostHeader()) +func 
TestConfig_HostHeader_PathStyle(t *testing.T) { + c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true} + require.Equal(t, "s3.example.com", c.HostHeader()) } -func TestClient_HostHeader_VirtualHosted(t *testing.T) { - c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} - require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader()) +func TestConfig_HostHeader_VirtualHosted(t *testing.T) { + c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false} + require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.HostHeader()) } func TestClient_ObjectKey(t *testing.T) { - c := &Client{Prefix: "attachments"} + c := &Client{config: &Config{Prefix: "attachments"}} require.Equal(t, "attachments/file123", c.objectKey("file123")) - c2 := &Client{Prefix: ""} + c2 := &Client{config: &Config{Prefix: ""}} require.Equal(t, "file123", c2.objectKey("file123")) } func TestClient_PrefixForList(t *testing.T) { - c := &Client{Prefix: "attachments"} + c := &Client{config: &Config{Prefix: "attachments"}} require.Equal(t, "attachments/", c.prefixForList()) - c2 := &Client{Prefix: ""} + c2 := &Client{config: &Config{Prefix: ""}} require.Equal(t, "", c2.prefixForList()) } @@ -512,13 +512,13 @@ func TestClient_ListObjects(t *testing.T) { require.Nil(t, err) // List with prefix client: should only see 3 - result, err := client.ListObjects(ctx, "", 0) + result, err := client.listObjects(ctx, "", 0) require.Nil(t, err) require.Len(t, result.Objects, 3) require.False(t, result.IsTruncated) // List with no-prefix client: should see all 4 - result, err = clientNoPrefix.ListObjects(ctx, "", 0) + result, err = clientNoPrefix.listObjects(ctx, "", 0) require.Nil(t, err) require.Len(t, result.Objects, 4) } @@ -537,20 +537,20 @@ func TestClient_ListObjects_Pagination(t *testing.T) { } // List with max-keys=2 - result, err := client.ListObjects(ctx, "", 2) + result, err := client.listObjects(ctx, 
"", 2) require.Nil(t, err) require.Len(t, result.Objects, 2) require.True(t, result.IsTruncated) require.NotEmpty(t, result.NextContinuationToken) // Get next page - result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2) + result2, err := client.listObjects(ctx, result.NextContinuationToken, 2) require.Nil(t, err) require.Len(t, result2.Objects, 2) require.True(t, result2.IsTruncated) // Get last page - result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2) + result3, err := client.listObjects(ctx, result2.NextContinuationToken, 2) require.Nil(t, err) require.Len(t, result3.Objects, 1) require.False(t, result3.IsTruncated) @@ -744,7 +744,7 @@ func TestClient_RealBucket(t *testing.T) { prefix = "ntfy-s3-test" } - client := &Client{ + client := New(&Config{ AccessKey: accessKey, SecretKey: secretKey, Region: region, @@ -752,7 +752,7 @@ func TestClient_RealBucket(t *testing.T) { Bucket: bucket, Prefix: prefix, PathStyle: pathStyle, - } + }) ctx := context.Background() @@ -762,8 +762,7 @@ func TestClient_RealBucket(t *testing.T) { if len(existing) > 0 { keys := make([]string, len(existing)) for i, obj := range existing { - // Strip the prefix since DeleteObjects will re-add it - keys[i] = strings.TrimPrefix(obj.Key, prefix+"/") + keys[i] = obj.Key } // Batch delete in groups of 1000 for i := 0; i < len(keys); i += 1000 { @@ -807,7 +806,7 @@ func TestClient_RealBucket(t *testing.T) { t.Run("ListObjects", func(t *testing.T) { // Use a sub-prefix client for isolation - listClient := &Client{ + listClient := New(&Config{ AccessKey: accessKey, SecretKey: secretKey, Region: region, @@ -815,7 +814,7 @@ func TestClient_RealBucket(t *testing.T) { Bucket: bucket, Prefix: prefix + "/list-test", PathStyle: pathStyle, - } + }) // Put 10 objects for i := 0; i < 10; i++ { diff --git a/s3/types.go b/s3/types.go index f78c4a8b..5fec7c78 100644 --- a/s3/types.go +++ b/s3/types.go @@ -2,18 +2,36 @@ package s3 import ( "fmt" + "net/http" "time" ) // 
Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string. type Config struct { - Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com" - PathStyle bool - Bucket string - Prefix string - Region string - AccessKey string - SecretKey string + Endpoint string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com" + PathStyle bool + Bucket string + Prefix string + Region string + AccessKey string + SecretKey string + HTTPClient *http.Client // if nil, http.DefaultClient is used +} + +// BucketURL returns the base URL for bucket-level operations. +func (c *Config) BucketURL() string { + if c.PathStyle { + return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket) + } + return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint) +} + +// HostHeader returns the value for the Host header. +func (c *Config) HostHeader() string { + if c.PathStyle { + return c.Endpoint + } + return c.Bucket + "." + c.Endpoint } // Object represents an S3 object returned by list operations. @@ -23,8 +41,8 @@ type Object struct { LastModified time.Time } -// ListResult holds the response from a ListObjectsV2 call. -type ListResult struct { +// listResult holds the response from a single ListObjectsV2 page. 
+type listResult struct { Objects []Object IsTruncated bool NextContinuationToken string @@ -78,10 +96,10 @@ type MultipartUpload struct { // listMultipartUploadsResult is the XML response from S3 ListMultipartUploads type listMultipartUploadsResult struct { - Uploads []listUpload `xml:"Upload"` - IsTruncated bool `xml:"IsTruncated"` - NextKeyMarker string `xml:"NextKeyMarker"` - NextUploadIDMarker string `xml:"NextUploadIdMarker"` + Uploads []listUpload `xml:"Upload"` + IsTruncated bool `xml:"IsTruncated"` + NextKeyMarker string `xml:"NextKeyMarker"` + NextUploadIDMarker string `xml:"NextUploadIdMarker"` } type listUpload struct { diff --git a/s3/util_test.go b/s3/util_test.go index d30c5664..3f08911d 100644 --- a/s3/util_test.go +++ b/s3/util_test.go @@ -105,13 +105,13 @@ func TestHmacSHA256(t *testing.T) { } func TestSignV4_SetsRequiredHeaders(t *testing.T) { - c := &Client{ + c := &Client{config: &Config{ AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", - } + }} req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil) c.signV4(req, emptyPayloadHash) @@ -131,13 +131,13 @@ func TestSignV4_SetsRequiredHeaders(t *testing.T) { } func TestSignV4_UnsignedPayload(t *testing.T) { - c := &Client{ + c := &Client{config: &Config{ AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", - } + }} req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil) c.signV4(req, unsignedPayload) @@ -146,8 +146,8 @@ func TestSignV4_UnsignedPayload(t *testing.T) { } func TestSignV4_DifferentRegions(t *testing.T) { - c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"} - c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", 
Bucket: "b"} + c1 := &Client{config: &Config{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}} + c2 := &Client{config: &Config{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}} req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil) c1.signV4(req1, emptyPayloadHash)