Move stuff around

parent a1b403d23c · commit d86e20173c
8 changed files with 396 additions and 396 deletions
@@ -3,7 +3,6 @@ package attachment
 import (
 	"context"
 	"io"
-	"strings"
 	"time"

 	"heckel.io/ntfy/v2/log"
@@ -38,15 +37,10 @@ func (b *s3Backend) List() ([]object, error) {
 	if err != nil {
 		return nil, err
 	}
-	prefix := b.client.Prefix
 	result := make([]object, 0, len(objects))
 	for _, obj := range objects {
-		id := obj.Key
-		if prefix != "" {
-			id = strings.TrimPrefix(id, prefix+"/")
-		}
 		result = append(result, object{
-			ID:           id,
+			ID:           obj.Key,
 			Size:         obj.Size,
 			LastModified: obj.LastModified,
 		})
@@ -197,7 +197,7 @@ func TestS3Store_Sync_SkipsRecentFiles(t *testing.T) {
 func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string, totalSizeLimit int64) *Store {
 	t.Helper()
 	host := strings.TrimPrefix(server.URL, "https://")
-	backend := newS3Backend(&s3.Client{
+	backend := newS3Backend(s3.New(&s3.Config{
 		AccessKey: "AKID",
 		SecretKey: "SECRET",
 		Region:    "us-east-1",

@@ -206,7 +206,7 @@ func newTestS3Store(t *testing.T, server *httptest.Server, bucket, prefix string
 		Prefix:     prefix,
 		PathStyle:  true,
 		HTTPClient: server.Client(),
-	})
+	}))
 	cache, err := newStore(backend, totalSizeLimit, nil)
 	require.Nil(t, err)
 	t.Cleanup(func() { cache.Close() })

s3/client.go (407 lines changed)
@@ -8,14 +8,12 @@ import (
 	"context"
 	"crypto/md5" //nolint:gosec // MD5 is required by the S3 protocol for Content-MD5 headers
 	"encoding/base64"
 	"encoding/hex"
 	"encoding/xml"
-	"errors"
 	"fmt"
 	"io"
 	"net/http"
 	"net/url"
-	"sort"
 	"strconv"
 	"strings"
 	"time"
@@ -33,26 +31,19 @@ const (
 //
 // Fields must not be modified after the Client is passed to any method or goroutine.
 type Client struct {
-	AccessKey  string       // AWS access key ID
-	SecretKey  string       // AWS secret access key
-	Region     string       // e.g. "us-east-1"
-	Endpoint   string       // host[:port] only, e.g. "s3.amazonaws.com" or "nyc3.digitaloceanspaces.com"
-	Bucket     string       // S3 bucket name
-	Prefix     string       // optional key prefix (e.g. "attachments"); prepended to all keys automatically
-	PathStyle  bool         // if true, use path-style addressing; otherwise virtual-hosted-style
-	HTTPClient *http.Client // if nil, http.DefaultClient is used
+	config *Config
+	http   *http.Client
 }

 // New creates a new S3 client from the given Config.
 func New(config *Config) *Client {
 	httpClient := config.HTTPClient
 	if httpClient == nil {
 		httpClient = http.DefaultClient
 	}
 	return &Client{
-		AccessKey: config.AccessKey,
-		SecretKey: config.SecretKey,
-		Region:    config.Region,
-		Endpoint:  config.Endpoint,
-		Bucket:    config.Bucket,
-		Prefix:    config.Prefix,
-		PathStyle: config.PathStyle,
+		config: config,
+		http:   httpClient,
 	}
 }
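For orientation, construction now goes through Config; a minimal caller sketch (credentials, bucket, and prefix values are illustrative placeholders, not from this commit):

	cfg := &s3.Config{
		AccessKey: "AKID", // illustrative credentials
		SecretKey: "SECRET",
		Region:    "us-east-1",
		Endpoint:  "s3.us-east-1.amazonaws.com", // host[:port] only
		Bucket:    "my-bucket",
		Prefix:    "attachments", // optional key prefix
	}
	client := s3.New(cfg)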
@@ -87,7 +78,7 @@ func (c *Client) GetObject(ctx context.Context, key string) (io.ReadCloser, int64, error) {
 		return nil, 0, fmt.Errorf("s3: GetObject request: %w", err)
 	}
 	c.signV4(req, emptyPayloadHash)
-	resp, err := c.httpClient().Do(req)
+	resp, err := c.http.Do(req)
 	if err != nil {
 		return nil, 0, fmt.Errorf("s3: GetObject: %w", err)
 	}
@@ -116,35 +107,18 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
 	}
 	body.WriteString("</Delete>")
 	bodyBytes := body.Bytes()
-	payloadHash := sha256Hex(bodyBytes)

 	// Content-MD5 is required by the S3 protocol for DeleteObjects requests.
 	md5Sum := md5.Sum(bodyBytes) //nolint:gosec
 	contentMD5 := base64.StdEncoding.EncodeToString(md5Sum[:])

-	reqURL := c.bucketURL() + "?delete="
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
+	respBody, err := c.doWithBodyAndHeaders(ctx, http.MethodPost, c.config.BucketURL()+"?delete=", bodyBytes,
+		map[string]string{"Content-MD5": contentMD5}, "DeleteObjects")
 	if err != nil {
-		return fmt.Errorf("s3: DeleteObjects request: %w", err)
-	}
-	req.ContentLength = int64(len(bodyBytes))
-	req.Header.Set("Content-Type", "application/xml")
-	req.Header.Set("Content-MD5", contentMD5)
-	c.signV4(req, payloadHash)
-	resp, err := c.httpClient().Do(req)
-	if err != nil {
-		return fmt.Errorf("s3: DeleteObjects: %w", err)
-	}
-	defer resp.Body.Close()
-	if !isHTTPSuccess(resp) {
-		return parseError(resp)
+		return err
 	}

 	// S3 may return HTTP 200 with per-key errors in the response body
-	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
-	if err != nil {
-		return fmt.Errorf("s3: DeleteObjects read response: %w", err)
-	}
 	var result deleteResult
 	if err := xml.Unmarshal(respBody, &result); err != nil {
 		return nil // If we can't parse, assume success (Quiet mode returns empty body on success)
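The hunk starts mid-function; the <Delete> body is assembled earlier in DeleteObjects, roughly like this sketch (Quiet mode assumed from the trailing comment; the real code may also XML-escape keys):

	var body bytes.Buffer
	body.WriteString("<Delete><Quiet>true</Quiet>")
	for _, key := range keys {
		// one <Object> per key; keys get the configured prefix prepended
		fmt.Fprintf(&body, "<Object><Key>%s</Key></Object>", c.objectKey(key))
	}
	body.WriteString("</Delete>")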
@@ -159,9 +133,9 @@ func (c *Client) DeleteObjects(ctx context.Context, keys []string) error {
 	return nil
 }

-// ListObjects performs a single ListObjectsV2 request using the client's configured prefix.
+// listObjects performs a single ListObjectsV2 request using the client's configured prefix.
 // Use continuationToken for pagination. Set maxKeys to 0 for the server default (typically 1000).
-func (c *Client) ListObjects(ctx context.Context, continuationToken string, maxKeys int) (*ListResult, error) {
+func (c *Client) listObjects(ctx context.Context, continuationToken string, maxKeys int) (*listResult, error) {
 	log.Tag(tagS3Client).Debug("ListObjects continuation=%s maxKeys=%d", continuationToken, maxKeys)
 	query := url.Values{"list-type": {"2"}}
 	if prefix := c.prefixForList(); prefix != "" {
@@ -173,22 +147,9 @@ func (c *Client) listObjects(ctx context.Context, continuationToken string, maxKeys int) (*listResult, error) {
 	if maxKeys > 0 {
 		query.Set("max-keys", strconv.Itoa(maxKeys))
 	}
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
+	respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListObjects")
 	if err != nil {
-		return nil, fmt.Errorf("s3: ListObjects request: %w", err)
-	}
-	c.signV4(req, emptyPayloadHash)
-	resp, err := c.httpClient().Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("s3: ListObjects: %w", err)
-	}
-	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
-	resp.Body.Close()
-	if err != nil {
-		return nil, fmt.Errorf("s3: ListObjects read: %w", err)
-	}
-	if !isHTTPSuccess(resp) {
-		return nil, parseErrorFromBytes(resp.StatusCode, respBody)
+		return nil, err
 	}
 	var result listObjectsV2Response
 	if err := xml.Unmarshal(respBody, &result); err != nil {
@@ -206,7 +167,7 @@ func (c *Client) listObjects(ctx context.Context, continuationToken string, maxKeys int) (*listResult, error) {
 			LastModified: lastModified,
 		}
 	}
-	return &ListResult{
+	return &listResult{
 		Objects:               objects,
 		IsTruncated:           result.IsTruncated,
 		NextContinuationToken: result.NextContinuationToken,
@@ -214,16 +175,21 @@ func (c *Client) listObjects(ctx context.Context, continuationToken string, maxKeys int) (*listResult, error) {
 }

 // ListAllObjects returns all objects under the client's configured prefix by paginating through
-// ListObjectsV2 results automatically. It stops after 10,000 pages as a safety valve.
+// ListObjectsV2 results automatically. Keys in the returned objects have the prefix stripped,
+// so they match the keys used with PutObject/GetObject/DeleteObjects. It stops after 10,000
+// pages as a safety valve.
 func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
 	var all []Object
 	var token string
 	for page := 0; page < maxPages; page++ {
-		result, err := c.ListObjects(ctx, token, 0)
+		result, err := c.listObjects(ctx, token, 0)
 		if err != nil {
 			return nil, err
 		}
-		all = append(all, result.Objects...)
+		for _, obj := range result.Objects {
+			obj.Key = c.stripPrefix(obj.Key)
+			all = append(all, obj)
+		}
 		if !result.IsTruncated {
 			return all, nil
 		}
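A short caller-side sketch of the new behavior: keys come back prefix-stripped, so they can be passed straight back to GetObject or DeleteObjects (client and ctx as in the earlier sketch):

	objects, err := client.ListAllObjects(ctx)
	if err != nil {
		return err // illustrative error handling
	}
	for _, obj := range objects {
		fmt.Println(obj.Key, obj.Size, obj.LastModified) // e.g. "file123", not "attachments/file123"
	}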
@@ -232,77 +198,6 @@ func (c *Client) ListAllObjects(ctx context.Context) ([]Object, error) {
 	return nil, fmt.Errorf("s3: ListAllObjects exceeded %d pages", maxPages)
 }

-// ListMultipartUploads returns in-progress multipart uploads for the client's prefix.
-// It paginates automatically, stopping after 10,000 pages as a safety valve.
-func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) {
-	var all []MultipartUpload
-	var keyMarker, uploadIDMarker string
-	for page := 0; page < maxPages; page++ {
-		query := url.Values{"uploads": {""}}
-		if prefix := c.prefixForList(); prefix != "" {
-			query.Set("prefix", prefix)
-		}
-		if keyMarker != "" {
-			query.Set("key-marker", keyMarker)
-			query.Set("upload-id-marker", uploadIDMarker)
-		}
-		req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.bucketURL()+"?"+query.Encode(), nil)
-		if err != nil {
-			return nil, fmt.Errorf("s3: ListMultipartUploads request: %w", err)
-		}
-		c.signV4(req, emptyPayloadHash)
-		resp, err := c.httpClient().Do(req)
-		if err != nil {
-			return nil, fmt.Errorf("s3: ListMultipartUploads: %w", err)
-		}
-		respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
-		resp.Body.Close()
-		if err != nil {
-			return nil, fmt.Errorf("s3: ListMultipartUploads read: %w", err)
-		}
-		if !isHTTPSuccess(resp) {
-			return nil, parseErrorFromBytes(resp.StatusCode, respBody)
-		}
-		var result listMultipartUploadsResult
-		if err := xml.Unmarshal(respBody, &result); err != nil {
-			return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err)
-		}
-		for _, u := range result.Uploads {
-			var initiated time.Time
-			if u.Initiated != "" {
-				initiated, _ = time.Parse(time.RFC3339, u.Initiated)
-			}
-			all = append(all, MultipartUpload{
-				Key:       u.Key,
-				UploadID:  u.UploadID,
-				Initiated: initiated,
-			})
-		}
-		if !result.IsTruncated {
-			return all, nil
-		}
-		keyMarker = result.NextKeyMarker
-		uploadIDMarker = result.NextUploadIDMarker
-	}
-	return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages)
-}
-
-// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated
-// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads.
-func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error {
-	uploads, err := c.ListMultipartUploads(ctx)
-	if err != nil {
-		return err
-	}
-	for _, u := range uploads {
-		if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) {
-			log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated)
-			c.abortMultipartUpload(ctx, u.Key, u.UploadID)
-		}
-	}
-	return nil
-}
-
 // putObject uploads a body with known size using a simple PUT with UNSIGNED-PAYLOAD.
 func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
 	fullKey := c.objectKey(key)
@@ -312,235 +207,80 @@ func (c *Client) putObject(ctx context.Context, key string, body io.Reader, size int64) error {
 	}
 	req.ContentLength = size
 	c.signV4(req, unsignedPayload)
-	resp, err := c.httpClient().Do(req)
+	resp, err := c.http.Do(req)
 	if err != nil {
 		return fmt.Errorf("s3: PutObject: %w", err)
 	}
-	defer resp.Body.Close()
+	resp.Body.Close()
 	if !isHTTPSuccess(resp) {
 		return parseError(resp)
 	}
 	return nil
 }

-// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
-// chunks, uploading each as a separate part. This allows uploading without knowing the total
-// body size in advance.
-func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
-	fullKey := c.objectKey(key)
-
-	// Step 1: Initiate multipart upload
-	uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
-	if err != nil {
-		return err
-	}
-
-	// Step 2: Upload parts
-	var parts []completedPart
-	buf := make([]byte, partSize)
-	partNumber := 1
-	for {
-		n, err := io.ReadFull(body, buf)
-		if n > 0 {
-			etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
-			if uploadErr != nil {
-				c.abortMultipartUpload(ctx, fullKey, uploadID)
-				return uploadErr
-			}
-			parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
-			partNumber++
-		}
-		if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
-			break
-		}
-		if err != nil {
-			c.abortMultipartUpload(ctx, fullKey, uploadID)
-			return fmt.Errorf("s3: PutObject read: %w", err)
-		}
-	}
-
-	// Step 3: Complete multipart upload
-	return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
-}
-
-// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
-func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
-	reqURL := c.objectURL(fullKey) + "?uploads"
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, nil)
-	if err != nil {
-		return "", fmt.Errorf("s3: InitiateMultipartUpload request: %w", err)
-	}
-	req.ContentLength = 0
-	c.signV4(req, emptyPayloadHash)
-	resp, err := c.httpClient().Do(req)
-	if err != nil {
-		return "", fmt.Errorf("s3: InitiateMultipartUpload: %w", err)
-	}
-	defer resp.Body.Close()
-	if !isHTTPSuccess(resp) {
-		return "", parseError(resp)
-	}
-	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
-	if err != nil {
-		return "", fmt.Errorf("s3: InitiateMultipartUpload read: %w", err)
-	}
-	var result initiateMultipartUploadResult
-	if err := xml.Unmarshal(respBody, &result); err != nil {
-		return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
-	}
-	log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID)
-	return result.UploadID, nil
-}
-
-// uploadPart uploads a single part of a multipart upload and returns the ETag.
-func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
-	log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data))
-	reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
-	req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
-	if err != nil {
-		return "", fmt.Errorf("s3: UploadPart request: %w", err)
-	}
-	req.ContentLength = int64(len(data))
-	c.signV4(req, unsignedPayload)
-	resp, err := c.httpClient().Do(req)
-	if err != nil {
-		return "", fmt.Errorf("s3: UploadPart: %w", err)
-	}
-	defer resp.Body.Close()
-	if !isHTTPSuccess(resp) {
-		return "", parseError(resp)
-	}
-	etag := resp.Header.Get("ETag")
-	return etag, nil
-}
-
-// completeMultipartUpload finalizes a multipart upload with the given parts.
-func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
-	log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts))
-	var body bytes.Buffer
-	body.WriteString("<CompleteMultipartUpload>")
-	for _, p := range parts {
-		fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
-	}
-	body.WriteString("</CompleteMultipartUpload>")
-	bodyBytes := body.Bytes()
-	payloadHash := sha256Hex(bodyBytes)
-
-	reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(bodyBytes))
-	if err != nil {
-		return fmt.Errorf("s3: CompleteMultipartUpload request: %w", err)
-	}
-	req.ContentLength = int64(len(bodyBytes))
-	req.Header.Set("Content-Type", "application/xml")
-	c.signV4(req, payloadHash)
-	resp, err := c.httpClient().Do(req)
-	if err != nil {
-		return fmt.Errorf("s3: CompleteMultipartUpload: %w", err)
-	}
-	defer resp.Body.Close()
-	if !isHTTPSuccess(resp) {
-		return parseError(resp)
-	}
-	// Read response body to check for errors (S3 can return 200 with an error body)
-	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
-	if err != nil {
-		return fmt.Errorf("s3: CompleteMultipartUpload read: %w", err)
-	}
-	// Check if the response contains an error
-	var errResp ErrorResponse
-	if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
-		errResp.StatusCode = resp.StatusCode
-		return &errResp
-	}
-	return nil
-}
-
-// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
-func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
-	log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID)
-	reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
-	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
-	if err != nil {
-		return
-	}
-	c.signV4(req, emptyPayloadHash)
-	resp, err := c.httpClient().Do(req)
-	if err != nil {
-		return
-	}
-	resp.Body.Close()
-}
-
-// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
-// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
-func (c *Client) signV4(req *http.Request, payloadHash string) {
-	now := time.Now().UTC()
-	datestamp := now.Format("20060102")
-	amzDate := now.Format("20060102T150405Z")
-
-	// Required headers
-	req.Header.Set("Host", c.hostHeader())
-	req.Header.Set("X-Amz-Date", amzDate)
-	req.Header.Set("X-Amz-Content-Sha256", payloadHash)
-
-	// Canonical headers (all headers we set, sorted by lowercase key)
-	signedKeys := make([]string, 0, len(req.Header))
-	canonHeaders := make(map[string]string, len(req.Header))
-	for k := range req.Header {
-		lk := strings.ToLower(k)
-		signedKeys = append(signedKeys, lk)
-		canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
-	}
-	sort.Strings(signedKeys)
-	signedHeadersStr := strings.Join(signedKeys, ";")
-	var chBuf strings.Builder
-	for _, k := range signedKeys {
-		chBuf.WriteString(k)
-		chBuf.WriteByte(':')
-		chBuf.WriteString(canonHeaders[k])
-		chBuf.WriteByte('\n')
-	}
-
-	// Canonical request
-	canonicalRequest := strings.Join([]string{
-		req.Method,
-		canonicalURI(req.URL),
-		canonicalQueryString(req.URL.Query()),
-		chBuf.String(),
-		signedHeadersStr,
-		payloadHash,
-	}, "\n")
-
-	// String to sign
-	credentialScope := datestamp + "/" + c.Region + "/s3/aws4_request"
-	stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))
-
-	// Signing key
-	signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
-		[]byte("AWS4"+c.SecretKey), []byte(datestamp)),
-		[]byte(c.Region)),
-		[]byte("s3")),
-		[]byte("aws4_request"))
-
-	signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
-	req.Header.Set("Authorization", fmt.Sprintf(
-		"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
-		c.AccessKey, credentialScope, signedHeadersStr, signature,
-	))
-}
-
-func (c *Client) httpClient() *http.Client {
-	if c.HTTPClient != nil {
-		return c.HTTPClient
-	}
-	return http.DefaultClient
-}
+// do creates a request, signs it with an empty payload, executes it, reads the response body,
+// and checks for errors. It is used for bodiless GET/POST requests.
+func (c *Client) do(ctx context.Context, method, reqURL string, body io.Reader, op string) ([]byte, error) {
+	req, err := http.NewRequestWithContext(ctx, method, reqURL, body)
+	if err != nil {
+		return nil, fmt.Errorf("s3: %s request: %w", op, err)
+	}
+	if body == nil {
+		req.ContentLength = 0
+	}
+	c.signV4(req, emptyPayloadHash)
+	resp, err := c.http.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("s3: %s: %w", op, err)
+	}
+	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
+	resp.Body.Close()
+	if err != nil {
+		return nil, fmt.Errorf("s3: %s read: %w", op, err)
+	}
+	if !isHTTPSuccess(resp) {
+		return nil, parseErrorFromBytes(resp.StatusCode, respBody)
+	}
+	return respBody, nil
+}
+
+// doWithBody is like do, but sends a body with a computed SHA-256 payload hash and Content-Type: application/xml.
+func (c *Client) doWithBody(ctx context.Context, method, reqURL string, bodyBytes []byte, op string) ([]byte, error) {
+	return c.doWithBodyAndHeaders(ctx, method, reqURL, bodyBytes, nil, op)
+}
+
+// doWithBodyAndHeaders is like doWithBody, but allows setting additional headers (e.g. Content-MD5).
+func (c *Client) doWithBodyAndHeaders(ctx context.Context, method, reqURL string, bodyBytes []byte, headers map[string]string, op string) ([]byte, error) {
+	payloadHash := sha256Hex(bodyBytes)
+	req, err := http.NewRequestWithContext(ctx, method, reqURL, bytes.NewReader(bodyBytes))
+	if err != nil {
+		return nil, fmt.Errorf("s3: %s request: %w", op, err)
+	}
+	req.ContentLength = int64(len(bodyBytes))
+	req.Header.Set("Content-Type", "application/xml")
+	for k, v := range headers {
+		req.Header.Set(k, v)
+	}
+	c.signV4(req, payloadHash)
+	resp, err := c.http.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("s3: %s: %w", op, err)
+	}
+	respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
+	resp.Body.Close()
+	if err != nil {
+		return nil, fmt.Errorf("s3: %s read: %w", op, err)
+	}
+	if !isHTTPSuccess(resp) {
+		return nil, parseErrorFromBytes(resp.StatusCode, respBody)
+	}
+	return respBody, nil
+}

 // objectKey prepends the configured prefix to the given key.
 func (c *Client) objectKey(key string) string {
-	if c.Prefix != "" {
-		return c.Prefix + "/" + key
+	if c.config.Prefix != "" {
+		return c.config.Prefix + "/" + key
 	}
 	return key
 }
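All helper errors are wrapped with the operation name, and HTTP-level failures surface via parseErrorFromBytes. A caller-side sketch (assuming *ErrorResponse implements error, as completeMultipartUpload's `return &errResp` suggests):

	if err := client.DeleteObjects(ctx, keys); err != nil {
		var s3Err *s3.ErrorResponse
		if errors.As(err, &s3Err) {
			fmt.Println("S3 error:", s3Err.Code, s3Err.StatusCode)
		}
	}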
@@ -548,18 +288,19 @@ func (c *Client) objectKey(key string) string {
 // prefixForList returns the prefix to use in ListObjectsV2 requests,
 // with a trailing slash so that only objects under the prefix directory are returned.
 func (c *Client) prefixForList() string {
-	if c.Prefix != "" {
-		return c.Prefix + "/"
+	if c.config.Prefix != "" {
+		return c.config.Prefix + "/"
 	}
 	return ""
 }

-// bucketURL returns the base URL for bucket-level operations.
-func (c *Client) bucketURL() string {
-	if c.PathStyle {
-		return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
-	}
-	return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
+// stripPrefix removes the configured prefix from a key returned by ListObjectsV2,
+// so keys match what was passed to PutObject/GetObject/DeleteObjects.
+func (c *Client) stripPrefix(key string) string {
+	if c.config.Prefix != "" {
+		return strings.TrimPrefix(key, c.config.Prefix+"/")
+	}
+	return key
 }

 // objectURL returns the full URL for an object (key should already include the prefix).
@@ -569,13 +310,5 @@ func (c *Client) objectURL(key string) string {
 	for i, seg := range segments {
 		segments[i] = uriEncode(seg)
 	}
-	return c.bucketURL() + "/" + strings.Join(segments, "/")
+	return c.config.BucketURL() + "/" + strings.Join(segments, "/")
 }
-
-// hostHeader returns the value for the Host header.
-func (c *Client) hostHeader() string {
-	if c.PathStyle {
-		return c.Endpoint
-	}
-	return c.Bucket + "." + c.Endpoint
-}

s3/client_auth.go (new file, 68 lines)
@@ -0,0 +1,68 @@
package s3

import (
	"encoding/hex"
	"fmt"
	"net/http"
	"sort"
	"strings"
	"time"
)

// signV4 signs req in place using AWS Signature V4. payloadHash is the hex-encoded SHA-256
// of the request body, or the literal string "UNSIGNED-PAYLOAD" for streaming uploads.
func (c *Client) signV4(req *http.Request, payloadHash string) {
	now := time.Now().UTC()
	datestamp := now.Format("20060102")
	amzDate := now.Format("20060102T150405Z")

	// Required headers
	req.Header.Set("Host", c.config.HostHeader())
	req.Header.Set("X-Amz-Date", amzDate)
	req.Header.Set("X-Amz-Content-Sha256", payloadHash)

	// Canonical headers (all headers we set, sorted by lowercase key)
	signedKeys := make([]string, 0, len(req.Header))
	canonHeaders := make(map[string]string, len(req.Header))
	for k := range req.Header {
		lk := strings.ToLower(k)
		signedKeys = append(signedKeys, lk)
		canonHeaders[lk] = strings.TrimSpace(req.Header.Get(k))
	}
	sort.Strings(signedKeys)
	signedHeadersStr := strings.Join(signedKeys, ";")
	var chBuf strings.Builder
	for _, k := range signedKeys {
		chBuf.WriteString(k)
		chBuf.WriteByte(':')
		chBuf.WriteString(canonHeaders[k])
		chBuf.WriteByte('\n')
	}

	// Canonical request
	canonicalRequest := strings.Join([]string{
		req.Method,
		canonicalURI(req.URL),
		canonicalQueryString(req.URL.Query()),
		chBuf.String(),
		signedHeadersStr,
		payloadHash,
	}, "\n")

	// String to sign
	credentialScope := datestamp + "/" + c.config.Region + "/s3/aws4_request"
	stringToSign := "AWS4-HMAC-SHA256\n" + amzDate + "\n" + credentialScope + "\n" + sha256Hex([]byte(canonicalRequest))

	// Signing key
	signingKey := hmacSHA256(hmacSHA256(hmacSHA256(hmacSHA256(
		[]byte("AWS4"+c.config.SecretKey), []byte(datestamp)),
		[]byte(c.config.Region)),
		[]byte("s3")),
		[]byte("aws4_request"))

	signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
	req.Header.Set("Authorization", fmt.Sprintf(
		"AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
		c.config.AccessKey, credentialScope, signedHeadersStr, signature,
	))
}
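The new file leans on two hash helpers (hmacSHA256, sha256Hex) defined elsewhere in the package; typical implementations look like this sketch (not necessarily the package's exact code):

	func hmacSHA256(key, data []byte) []byte {
		h := hmac.New(sha256.New, key) // crypto/hmac, crypto/sha256
		h.Write(data)
		return h.Sum(nil)
	}

	func sha256Hex(data []byte) string {
		sum := sha256.Sum256(data)
		return hex.EncodeToString(sum[:]) // encoding/hex
	}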

s3/client_multipart.go (new file, 188 lines)
@@ -0,0 +1,188 @@
package s3

import (
	"bytes"
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"

	"heckel.io/ntfy/v2/log"
)

// ListMultipartUploads returns in-progress multipart uploads for the client's prefix.
// It paginates automatically, stopping after 10,000 pages as a safety valve.
func (c *Client) ListMultipartUploads(ctx context.Context) ([]MultipartUpload, error) {
	var all []MultipartUpload
	var keyMarker, uploadIDMarker string
	for page := 0; page < maxPages; page++ {
		query := url.Values{"uploads": {""}}
		if prefix := c.prefixForList(); prefix != "" {
			query.Set("prefix", prefix)
		}
		if keyMarker != "" {
			query.Set("key-marker", keyMarker)
			query.Set("upload-id-marker", uploadIDMarker)
		}
		respBody, err := c.do(ctx, http.MethodGet, c.config.BucketURL()+"?"+query.Encode(), nil, "ListMultipartUploads")
		if err != nil {
			return nil, err
		}
		var result listMultipartUploadsResult
		if err := xml.Unmarshal(respBody, &result); err != nil {
			return nil, fmt.Errorf("s3: ListMultipartUploads XML: %w", err)
		}
		for _, u := range result.Uploads {
			var initiated time.Time
			if u.Initiated != "" {
				initiated, _ = time.Parse(time.RFC3339, u.Initiated)
			}
			all = append(all, MultipartUpload{
				Key:       u.Key,
				UploadID:  u.UploadID,
				Initiated: initiated,
			})
		}
		if !result.IsTruncated {
			return all, nil
		}
		keyMarker = result.NextKeyMarker
		uploadIDMarker = result.NextUploadIDMarker
	}
	return nil, fmt.Errorf("s3: ListMultipartUploads exceeded %d pages", maxPages)
}

// AbortIncompleteUploads lists all in-progress multipart uploads and aborts those initiated
// before the given cutoff time. This cleans up orphaned upload parts from interrupted uploads.
func (c *Client) AbortIncompleteUploads(ctx context.Context, cutoff time.Time) error {
	uploads, err := c.ListMultipartUploads(ctx)
	if err != nil {
		return err
	}
	for _, u := range uploads {
		if !u.Initiated.IsZero() && u.Initiated.Before(cutoff) {
			log.Tag(tagS3Client).Debug("DeleteIncomplete key=%s uploadId=%s initiated=%s", u.Key, u.UploadID, u.Initiated)
			c.abortMultipartUpload(ctx, u.Key, u.UploadID)
		}
	}
	return nil
}

// putObjectMultipart uploads body using S3 multipart upload. It reads the body in partSize
// chunks, uploading each as a separate part. This allows uploading without knowing the total
// body size in advance.
func (c *Client) putObjectMultipart(ctx context.Context, key string, body io.Reader) error {
	fullKey := c.objectKey(key)

	// Step 1: Initiate multipart upload
	uploadID, err := c.initiateMultipartUpload(ctx, fullKey)
	if err != nil {
		return err
	}

	// Step 2: Upload parts
	var parts []completedPart
	buf := make([]byte, partSize)
	partNumber := 1
	for {
		n, err := io.ReadFull(body, buf)
		if n > 0 {
			etag, uploadErr := c.uploadPart(ctx, fullKey, uploadID, partNumber, buf[:n])
			if uploadErr != nil {
				c.abortMultipartUpload(ctx, fullKey, uploadID)
				return uploadErr
			}
			parts = append(parts, completedPart{PartNumber: partNumber, ETag: etag})
			partNumber++
		}
		if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) {
			break
		}
		if err != nil {
			c.abortMultipartUpload(ctx, fullKey, uploadID)
			return fmt.Errorf("s3: PutObject read: %w", err)
		}
	}

	// Step 3: Complete multipart upload
	return c.completeMultipartUpload(ctx, fullKey, uploadID, parts)
}

// initiateMultipartUpload starts a new multipart upload and returns the upload ID.
func (c *Client) initiateMultipartUpload(ctx context.Context, fullKey string) (string, error) {
	respBody, err := c.do(ctx, http.MethodPost, c.objectURL(fullKey)+"?uploads", nil, "InitiateMultipartUpload")
	if err != nil {
		return "", err
	}
	var result initiateMultipartUploadResult
	if err := xml.Unmarshal(respBody, &result); err != nil {
		return "", fmt.Errorf("s3: InitiateMultipartUpload XML: %w", err)
	}
	log.Tag(tagS3Client).Debug("InitiateMultipartUpload key=%s uploadId=%s", fullKey, result.UploadID)
	return result.UploadID, nil
}

// uploadPart uploads a single part of a multipart upload and returns the ETag.
func (c *Client) uploadPart(ctx context.Context, fullKey, uploadID string, partNumber int, data []byte) (string, error) {
	log.Tag(tagS3Client).Debug("UploadPart key=%s part=%d size=%d", fullKey, partNumber, len(data))
	reqURL := fmt.Sprintf("%s?partNumber=%d&uploadId=%s", c.objectURL(fullKey), partNumber, url.QueryEscape(uploadID))
	req, err := http.NewRequestWithContext(ctx, http.MethodPut, reqURL, bytes.NewReader(data))
	if err != nil {
		return "", fmt.Errorf("s3: UploadPart request: %w", err)
	}
	req.ContentLength = int64(len(data))
	c.signV4(req, unsignedPayload)
	resp, err := c.http.Do(req)
	if err != nil {
		return "", fmt.Errorf("s3: UploadPart: %w", err)
	}
	defer resp.Body.Close()
	if !isHTTPSuccess(resp) {
		return "", parseError(resp)
	}
	etag := resp.Header.Get("ETag")
	return etag, nil
}

// completeMultipartUpload finalizes a multipart upload with the given parts.
func (c *Client) completeMultipartUpload(ctx context.Context, fullKey, uploadID string, parts []completedPart) error {
	log.Tag(tagS3Client).Debug("CompleteMultipartUpload key=%s uploadId=%s parts=%d", fullKey, uploadID, len(parts))
	var body bytes.Buffer
	body.WriteString("<CompleteMultipartUpload>")
	for _, p := range parts {
		fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", p.PartNumber, p.ETag)
	}
	body.WriteString("</CompleteMultipartUpload>")
	respBody, err := c.doWithBody(ctx, http.MethodPost,
		fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID)),
		body.Bytes(), "CompleteMultipartUpload")
	if err != nil {
		return err
	}
	// Check if the response contains an error (S3 can return 200 with an error body)
	var errResp ErrorResponse
	if xml.Unmarshal(respBody, &errResp) == nil && errResp.Code != "" {
		return &errResp
	}
	return nil
}

// abortMultipartUpload cancels an in-progress multipart upload. Called on error to clean up.
func (c *Client) abortMultipartUpload(ctx context.Context, fullKey, uploadID string) {
	log.Tag(tagS3Client).Debug("AbortMultipartUpload key=%s uploadId=%s", fullKey, uploadID)
	reqURL := fmt.Sprintf("%s?uploadId=%s", c.objectURL(fullKey), url.QueryEscape(uploadID))
	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, reqURL, nil)
	if err != nil {
		return
	}
	c.signV4(req, emptyPayloadHash)
	resp, err := c.http.Do(req)
	if err != nil {
		return
	}
	resp.Body.Close()
}
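The part-upload loop above relies on io.ReadFull's contract: a short final read returns the bytes read plus io.ErrUnexpectedEOF, so the last partial chunk is still uploaded before the loop breaks. A tiny standalone illustration:

	buf := make([]byte, 8)
	r := strings.NewReader("hello") // shorter than one part
	n, err := io.ReadFull(r, buf)
	fmt.Println(n, err) // 5 unexpected EOF -> upload buf[:5], then stop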

@@ -269,7 +269,7 @@ func (m *mockS3Server) objectCount() int {
 func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
 	// httptest.NewTLSServer URL is like "https://127.0.0.1:PORT"
 	host := strings.TrimPrefix(server.URL, "https://")
-	return &Client{
+	return New(&Config{
 		AccessKey: "AKID",
 		SecretKey: "SECRET",
 		Region:    "us-east-1",
@@ -278,7 +278,7 @@ func newTestClient(server *httptest.Server, bucket, prefix string) *Client {
 		Prefix:     prefix,
 		PathStyle:  true,
 		HTTPClient: server.Client(),
-	}
+	})
 }

 // --- URL parsing tests ---
@@ -363,49 +363,49 @@ func TestParseURL_EmptyBucket(t *testing.T) {

 // --- Unit tests: URL construction ---

-func TestClient_BucketURL_PathStyle(t *testing.T) {
-	c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
-	require.Equal(t, "https://s3.example.com/my-bucket", c.bucketURL())
+func TestConfig_BucketURL_PathStyle(t *testing.T) {
+	c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
+	require.Equal(t, "https://s3.example.com/my-bucket", c.BucketURL())
 }

-func TestClient_BucketURL_VirtualHosted(t *testing.T) {
-	c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
-	require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.bucketURL())
+func TestConfig_BucketURL_VirtualHosted(t *testing.T) {
+	c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
+	require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com", c.BucketURL())
 }

 func TestClient_ObjectURL_PathStyle(t *testing.T) {
-	c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
+	c := &Client{config: &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}}
 	require.Equal(t, "https://s3.example.com/my-bucket/prefix/obj", c.objectURL("prefix/obj"))
 }

 func TestClient_ObjectURL_VirtualHosted(t *testing.T) {
-	c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
+	c := &Client{config: &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}}
 	require.Equal(t, "https://my-bucket.s3.us-east-1.amazonaws.com/prefix/obj", c.objectURL("prefix/obj"))
 }

-func TestClient_HostHeader_PathStyle(t *testing.T) {
-	c := &Client{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
-	require.Equal(t, "s3.example.com", c.hostHeader())
+func TestConfig_HostHeader_PathStyle(t *testing.T) {
+	c := &Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
+	require.Equal(t, "s3.example.com", c.HostHeader())
 }

-func TestClient_HostHeader_VirtualHosted(t *testing.T) {
-	c := &Client{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
-	require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.hostHeader())
+func TestConfig_HostHeader_VirtualHosted(t *testing.T) {
+	c := &Config{Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "my-bucket", PathStyle: false}
+	require.Equal(t, "my-bucket.s3.us-east-1.amazonaws.com", c.HostHeader())
 }

 func TestClient_ObjectKey(t *testing.T) {
-	c := &Client{Prefix: "attachments"}
+	c := &Client{config: &Config{Prefix: "attachments"}}
 	require.Equal(t, "attachments/file123", c.objectKey("file123"))

-	c2 := &Client{Prefix: ""}
+	c2 := &Client{config: &Config{Prefix: ""}}
 	require.Equal(t, "file123", c2.objectKey("file123"))
 }

 func TestClient_PrefixForList(t *testing.T) {
-	c := &Client{Prefix: "attachments"}
+	c := &Client{config: &Config{Prefix: "attachments"}}
 	require.Equal(t, "attachments/", c.prefixForList())

-	c2 := &Client{Prefix: ""}
+	c2 := &Client{config: &Config{Prefix: ""}}
 	require.Equal(t, "", c2.prefixForList())
 }
@@ -512,13 +512,13 @@ func TestClient_ListObjects(t *testing.T) {
 	require.Nil(t, err)

 	// List with prefix client: should only see 3
-	result, err := client.ListObjects(ctx, "", 0)
+	result, err := client.listObjects(ctx, "", 0)
 	require.Nil(t, err)
 	require.Len(t, result.Objects, 3)
 	require.False(t, result.IsTruncated)

 	// List with no-prefix client: should see all 4
-	result, err = clientNoPrefix.ListObjects(ctx, "", 0)
+	result, err = clientNoPrefix.listObjects(ctx, "", 0)
 	require.Nil(t, err)
 	require.Len(t, result.Objects, 4)
 }
@@ -537,20 +537,20 @@ func TestClient_ListObjects_Pagination(t *testing.T) {
 	}

 	// List with max-keys=2
-	result, err := client.ListObjects(ctx, "", 2)
+	result, err := client.listObjects(ctx, "", 2)
 	require.Nil(t, err)
 	require.Len(t, result.Objects, 2)
 	require.True(t, result.IsTruncated)
 	require.NotEmpty(t, result.NextContinuationToken)

 	// Get next page
-	result2, err := client.ListObjects(ctx, result.NextContinuationToken, 2)
+	result2, err := client.listObjects(ctx, result.NextContinuationToken, 2)
 	require.Nil(t, err)
 	require.Len(t, result2.Objects, 2)
 	require.True(t, result2.IsTruncated)

 	// Get last page
-	result3, err := client.ListObjects(ctx, result2.NextContinuationToken, 2)
+	result3, err := client.listObjects(ctx, result2.NextContinuationToken, 2)
 	require.Nil(t, err)
 	require.Len(t, result3.Objects, 1)
 	require.False(t, result3.IsTruncated)
@@ -744,7 +744,7 @@ func TestClient_RealBucket(t *testing.T) {
 		prefix = "ntfy-s3-test"
 	}

-	client := &Client{
+	client := New(&Config{
 		AccessKey: accessKey,
 		SecretKey: secretKey,
 		Region:    region,
@@ -752,7 +752,7 @@ func TestClient_RealBucket(t *testing.T) {
 		Bucket:    bucket,
 		Prefix:    prefix,
 		PathStyle: pathStyle,
-	}
+	})

 	ctx := context.Background()
@@ -762,8 +762,7 @@ func TestClient_RealBucket(t *testing.T) {
 	if len(existing) > 0 {
 		keys := make([]string, len(existing))
 		for i, obj := range existing {
-			// Strip the prefix since DeleteObjects will re-add it
-			keys[i] = strings.TrimPrefix(obj.Key, prefix+"/")
+			keys[i] = obj.Key
 		}
 		// Batch delete in groups of 1000
 		for i := 0; i < len(keys); i += 1000 {
@@ -807,7 +806,7 @@ func TestClient_RealBucket(t *testing.T) {

 	t.Run("ListObjects", func(t *testing.T) {
 		// Use a sub-prefix client for isolation
-		listClient := &Client{
+		listClient := New(&Config{
 			AccessKey: accessKey,
 			SecretKey: secretKey,
 			Region:    region,
@@ -815,7 +814,7 @@ func TestClient_RealBucket(t *testing.T) {
 			Bucket:    bucket,
 			Prefix:    prefix + "/list-test",
 			PathStyle: pathStyle,
-		}
+		})

 		// Put 10 objects
 		for i := 0; i < 10; i++ {

s3/types.go (44 lines changed)
@@ -2,18 +2,36 @@ package s3

 import (
+	"fmt"
+	"net/http"
 	"time"
 )

 // Config holds the parsed fields from an S3 URL. Use ParseURL to create one from a URL string.
 type Config struct {
-	Endpoint  string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
-	PathStyle bool
-	Bucket    string
-	Prefix    string
-	Region    string
-	AccessKey string
-	SecretKey string
+	Endpoint   string // host[:port] only, e.g. "s3.us-east-1.amazonaws.com"
+	PathStyle  bool
+	Bucket     string
+	Prefix     string
+	Region     string
+	AccessKey  string
+	SecretKey  string
+	HTTPClient *http.Client // if nil, http.DefaultClient is used
 }
+
+// BucketURL returns the base URL for bucket-level operations.
+func (c *Config) BucketURL() string {
+	if c.PathStyle {
+		return fmt.Sprintf("https://%s/%s", c.Endpoint, c.Bucket)
+	}
+	return fmt.Sprintf("https://%s.%s", c.Bucket, c.Endpoint)
+}
+
+// HostHeader returns the value for the Host header.
+func (c *Config) HostHeader() string {
+	if c.PathStyle {
+		return c.Endpoint
+	}
+	return c.Bucket + "." + c.Endpoint
+}

 // Object represents an S3 object returned by list operations.
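The two addressing modes produce, for example (mirroring the unit tests further down):

	cfg := &s3.Config{Endpoint: "s3.example.com", Bucket: "my-bucket", PathStyle: true}
	fmt.Println(cfg.BucketURL())  // https://s3.example.com/my-bucket
	cfg.PathStyle = false
	fmt.Println(cfg.BucketURL())  // https://my-bucket.s3.example.com
	fmt.Println(cfg.HostHeader()) // my-bucket.s3.example.com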
@@ -23,8 +41,8 @@ type Object struct {
 	LastModified time.Time
 }

-// ListResult holds the response from a ListObjectsV2 call.
-type ListResult struct {
+// listResult holds the response from a single ListObjectsV2 page.
+type listResult struct {
 	Objects               []Object
 	IsTruncated           bool
 	NextContinuationToken string
@@ -78,10 +96,10 @@ type MultipartUpload struct {

 // listMultipartUploadsResult is the XML response from S3 ListMultipartUploads
 type listMultipartUploadsResult struct {
-	Uploads            []listUpload  `xml:"Upload"`
-	IsTruncated        bool          `xml:"IsTruncated"`
-	NextKeyMarker      string        `xml:"NextKeyMarker"`
-	NextUploadIDMarker string        `xml:"NextUploadIdMarker"`
+	Uploads            []listUpload `xml:"Upload"`
+	IsTruncated        bool         `xml:"IsTruncated"`
+	NextKeyMarker      string       `xml:"NextKeyMarker"`
+	NextUploadIDMarker string       `xml:"NextUploadIdMarker"`
 }

 type listUpload struct {
@@ -105,13 +105,13 @@ func TestHmacSHA256(t *testing.T) {
 }

 func TestSignV4_SetsRequiredHeaders(t *testing.T) {
-	c := &Client{
+	c := &Client{config: &Config{
 		AccessKey: "AKID",
 		SecretKey: "SECRET",
 		Region:    "us-east-1",
 		Endpoint:  "s3.us-east-1.amazonaws.com",
 		Bucket:    "my-bucket",
-	}
+	}}

 	req, _ := http.NewRequest(http.MethodGet, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
 	c.signV4(req, emptyPayloadHash)
@@ -131,13 +131,13 @@ func TestSignV4_SetsRequiredHeaders(t *testing.T) {
 }

 func TestSignV4_UnsignedPayload(t *testing.T) {
-	c := &Client{
+	c := &Client{config: &Config{
 		AccessKey: "AKID",
 		SecretKey: "SECRET",
 		Region:    "us-east-1",
 		Endpoint:  "s3.us-east-1.amazonaws.com",
 		Bucket:    "my-bucket",
-	}
+	}}

 	req, _ := http.NewRequest(http.MethodPut, "https://my-bucket.s3.us-east-1.amazonaws.com/test-key", nil)
 	c.signV4(req, unsignedPayload)
@@ -146,8 +146,8 @@ func TestSignV4_UnsignedPayload(t *testing.T) {
 }

 func TestSignV4_DifferentRegions(t *testing.T) {
-	c1 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}
-	c2 := &Client{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}
+	c1 := &Client{config: &Config{AccessKey: "AKID", SecretKey: "SECRET", Region: "us-east-1", Endpoint: "s3.us-east-1.amazonaws.com", Bucket: "b"}}
+	c2 := &Client{config: &Config{AccessKey: "AKID", SecretKey: "SECRET", Region: "eu-west-1", Endpoint: "s3.eu-west-1.amazonaws.com", Bucket: "b"}}

 	req1, _ := http.NewRequest(http.MethodGet, "https://b.s3.us-east-1.amazonaws.com/key", nil)
 	c1.signV4(req1, emptyPayloadHash)