2024-11-18 09:44:32 +00:00
|
|
|
import logging
|
|
|
|
|
import re
|
|
|
|
|
from base64 import b64decode
|
|
|
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
from cryptography.exceptions import InvalidSignature
|
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
|
|
|
|
from cryptography.hazmat.primitives.hashes import SHA1, SHA256
|
|
|
|
|
from cryptography.x509 import NameOID, load_pem_x509_certificate
|
|
|
|
|
from django.conf import settings
|
2024-11-25 09:32:53 +00:00
|
|
|
from django.core.cache import cache
|
2024-11-18 09:44:32 +00:00
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Accepted host names for the signing certificate URL: an "sns." prefix, a
# label of at least three letters/digits/hyphens, then "amazonaws.com" with an
# optional ".cn" suffix.
HOST_PATTERN = re.compile(r"^sns\.[a-zA-Z0-9\-]{3,}\.amazonaws\.com(\.cn)?$")

# Keys that must be present in a message before any other check is attempted.
REQUIRED_KEYS = (
    "Message", "MessageId", "Timestamp", "TopicArn",
    "Type", "Signature", "SigningCertURL", "SignatureVersion",
)

# Keys, in order, whose name/value pairs make up the signed string for
# "Notification" messages.
SIGNING_KEYS_NOTIFICATION = (
    "Message",
    "MessageId",
    "Subject",
    "Timestamp",
    "TopicArn",
    "Type",
)

# Keys, in order, whose name/value pairs make up the signed string for
# "SubscriptionConfirmation" and "UnsubscribeConfirmation" messages.
SIGNING_KEYS_SUBSCRIPTION = (
    "Message",
    "MessageId",
    "SubscribeURL",
    "Timestamp",
    "Token",
    "TopicArn",
    "Type",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_amazon_sns_message(message: dict) -> bool:
    """
    Validate an AWS SNS message. Based on:
    - https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html
    - https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js
    - https://github.com/aws/aws-php-sns-message-validator/blob/3cee0fc1aee5538e1bd677654b09fad811061d0b/src/MessageValidator.php

    The checks are, in order: required keys, expected TopicArn, known message
    type, SigningCertURL shape, decodable signature, supported
    SignatureVersion, certificate issuer and finally the signature itself.
    Every check that only needs the message runs before the certificate is
    fetched, so a message that can be rejected locally never triggers an
    outbound request.

    Args:
        message: The parsed SNS message as a mapping of field name to value.

    Returns:
        True if every check passes. False as soon as one check fails; a
        warning describing the failed check is logged in that case.

    Raises:
        Errors raised by ``fetch_certificate`` and by parsing the downloaded
        certificate are not caught here and propagate to the caller.
    """
    # Check if the message has all the required keys
    if not all(key in message for key in REQUIRED_KEYS):
        logger.warning("Missing required keys in the message, got: %s", message.keys())
        return False

    # Check TopicArn
    if message["TopicArn"] != settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN:
        logger.warning("Invalid TopicArn: %s", message["TopicArn"])
        return False

    # Construct the canonical message
    if message["Type"] == "Notification":
        signing_keys = SIGNING_KEYS_NOTIFICATION
    elif message["Type"] in ("SubscriptionConfirmation", "UnsubscribeConfirmation"):
        signing_keys = SIGNING_KEYS_SUBSCRIPTION
    else:
        logger.warning("Invalid message type: %s", message["Type"])
        return False
    # Signing keys that are absent from the message are skipped.
    canonical_message = "".join(f"{key}\n{message[key]}\n" for key in signing_keys if key in message).encode()

    # Check if SigningCertURL is a valid SNS URL
    signing_cert_url = message["SigningCertURL"]
    parsed_url = urlparse(signing_cert_url)
    if (
        parsed_url.scheme != "https"
        or not HOST_PATTERN.match(parsed_url.netloc)
        or not parsed_url.path.endswith(".pem")
    ):
        logger.warning("Invalid SigningCertURL: %s", signing_cert_url)
        return False

    # Decode the signature. The value comes from the untrusted message, so a
    # malformed one must yield False rather than an unhandled exception
    # (binascii.Error is a ValueError subclass; non-text values raise TypeError).
    try:
        signature = b64decode(message["Signature"])
    except (TypeError, ValueError):
        logger.warning("Invalid signature encoding")
        return False

    # Select the digest matching the declared signature version
    if message["SignatureVersion"] == "1":
        hash_algorithm = SHA1()
    elif message["SignatureVersion"] == "2":
        hash_algorithm = SHA256()
    else:
        logger.warning("Invalid SignatureVersion: %s", message["SignatureVersion"])
        return False

    # Fetch the certificate
    certificate_bytes = fetch_certificate(signing_cert_url)

    # Verify the certificate issuer
    certificate = load_pem_x509_certificate(certificate_bytes)
    organizations = certificate.issuer.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)
    # An issuer without any organization attribute is rejected like a wrong one.
    if not organizations or organizations[0].value != "Amazon":
        logger.warning("Invalid certificate issuer: %s", certificate.issuer)
        return False

    # Verify the signature
    try:
        certificate.public_key().verify(signature, canonical_message, PKCS1v15(), hash_algorithm)
    except InvalidSignature:
        logger.warning("Invalid signature")
        return False

    return True
|
2024-11-25 09:32:53 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch_certificate(certificate_url: str) -> bytes:
    """
    Return the content served at ``certificate_url``, using the cache first.

    A truthy cached value is returned as-is. Otherwise the URL is downloaded
    with a 5 second timeout, the response body is stored in the cache for one
    hour and then returned.

    Args:
        certificate_url: URL to download the certificate from.

    Returns:
        The cached value, or the raw response body of the download.

    Raises:
        Whatever ``requests.get`` or ``raise_for_status`` raise; a failed
        download is not cached.
    """
    key = f"aws_sns_cert_{certificate_url}"
    pem = cache.get(key)

    if not pem:
        # Cache miss (or empty entry): download and remember the result.
        reply = requests.get(certificate_url, timeout=5)
        reply.raise_for_status()
        pem = reply.content
        cache.set(key, pem, timeout=60 * 60)  # Cache for 1 hour

    return pem
|