oncall-engine/engine/apps/email/validate_amazon_sns_message.py

109 lines
3.9 KiB
Python
Raw Permalink Normal View History

import logging
import re
from base64 import b64decode
from urllib.parse import urlparse
import requests
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.hashes import SHA1, SHA256
from cryptography.x509 import NameOID, load_pem_x509_certificate
from django.conf import settings
from django.core.cache import cache
logger = logging.getLogger(__name__)

# Hostnames accepted for SigningCertURL (matched against the URL's netloc):
# "sns.<region>.amazonaws.com", optionally followed by ".cn".
# The pattern is anchored with "\Z" rather than "$" because Python's "$" also
# matches just before a trailing newline, which would accept "…amazonaws.com\n".
HOST_PATTERN = re.compile(r"^sns\.[a-zA-Z0-9\-]{3,}\.amazonaws\.com(\.cn)?\Z")

# Keys that must be present in every message before any further validation.
REQUIRED_KEYS = (
    "Message",
    "MessageId",
    "Timestamp",
    "TopicArn",
    "Type",
    "Signature",
    "SigningCertURL",
    "SignatureVersion",
)

# Keys, in signing order, that make up the canonical string for each message
# type. Keys missing from a message (e.g. "Subject") are skipped when the
# canonical string is built.
SIGNING_KEYS_NOTIFICATION = ("Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type")
SIGNING_KEYS_SUBSCRIPTION = ("Message", "MessageId", "SubscribeURL", "Timestamp", "Token", "TopicArn", "Type")
def validate_amazon_sns_message(message: dict) -> bool:
    """
    Validate an AWS SNS message. Based on:
    - https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html
    - https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js
    - https://github.com/aws/aws-php-sns-message-validator/blob/3cee0fc1aee5538e1bd677654b09fad811061d0b/src/MessageValidator.php

    The checks run in this order, cheapest first, so that a malformed message
    never triggers a network request:
    required keys, TopicArn, message type, SigningCertURL shape,
    SignatureVersion, signature encoding, certificate issuer, signature.

    :param message: the parsed SNS message body.
    :return: True if every check passes; False otherwise (the reason is logged
        at warning level).

    Errors raised by ``fetch_certificate`` while downloading the signing
    certificate are not caught here and propagate to the caller.
    """
    # Check if the message has all the required keys
    if not all(key in message for key in REQUIRED_KEYS):
        logger.warning("Missing required keys in the message, got: %s", message.keys())
        return False

    # Check TopicArn
    if message["TopicArn"] != settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN:
        logger.warning("Invalid TopicArn: %s", message["TopicArn"])
        return False

    # Construct the canonical message
    if message["Type"] == "Notification":
        signing_keys = SIGNING_KEYS_NOTIFICATION
    elif message["Type"] in ("SubscriptionConfirmation", "UnsubscribeConfirmation"):
        signing_keys = SIGNING_KEYS_SUBSCRIPTION
    else:
        logger.warning("Invalid message type: %s", message["Type"])
        return False
    # Optional keys (e.g. "Subject") are simply left out when absent.
    canonical_message = "".join(f"{key}\n{message[key]}\n" for key in signing_keys if key in message).encode()

    # Check if SigningCertURL is a valid SNS URL
    signing_cert_url = message["SigningCertURL"]
    try:
        # urlparse raises ValueError on a malformed netloc (e.g. "https://[")
        # and fails on non-string input; treat both as an invalid URL.
        parsed_url = urlparse(signing_cert_url) if isinstance(signing_cert_url, str) else None
    except ValueError:
        parsed_url = None
    if (
        parsed_url is None
        or parsed_url.scheme != "https"
        or not HOST_PATTERN.match(parsed_url.netloc)
        or not parsed_url.path.endswith(".pem")
    ):
        logger.warning("Invalid SigningCertURL: %s", signing_cert_url)
        return False

    # Pick the hash algorithm before fetching anything over the network
    if message["SignatureVersion"] == "1":
        hash_algorithm = SHA1()
    elif message["SignatureVersion"] == "2":
        hash_algorithm = SHA256()
    else:
        logger.warning("Invalid SignatureVersion: %s", message["SignatureVersion"])
        return False

    # Decode the signature; malformed base64 (binascii.Error is a ValueError)
    # or a non-string value means the message cannot be authentic.
    try:
        signature = b64decode(message["Signature"])
    except (TypeError, ValueError):
        logger.warning("Invalid signature")
        return False

    # Fetch the certificate
    certificate_bytes = fetch_certificate(signing_cert_url)

    # Verify the certificate issuer
    try:
        certificate = load_pem_x509_certificate(certificate_bytes)
    except ValueError:
        # NOTE(review): assumes unparsable PEM data surfaces as ValueError.
        logger.warning("Invalid certificate fetched from: %s", signing_cert_url)
        return False
    organizations = certificate.issuer.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)
    # An issuer without an organization attribute yields an empty list.
    if not organizations or organizations[0].value != "Amazon":
        logger.warning("Invalid certificate issuer: %s", certificate.issuer)
        return False

    # Verify the signature
    try:
        certificate.public_key().verify(signature, canonical_message, PKCS1v15(), hash_algorithm)
    except InvalidSignature:
        logger.warning("Invalid signature")
        return False

    return True
def fetch_certificate(certificate_url: str) -> bytes:
    """
    Return the body served at ``certificate_url``, going through the Django cache.

    A truthy cached value is returned as-is. Otherwise the URL is downloaded
    with a 5 second timeout, the response body is stored in the cache for one
    hour and then returned. A non-success HTTP status raises via
    ``raise_for_status``; request errors are not caught here.
    """
    key = f"aws_sns_cert_{certificate_url}"
    body = cache.get(key)
    if not body:
        # Cache miss (or an empty cached value): download and remember it.
        reply = requests.get(certificate_url, timeout=5)
        reply.raise_for_status()
        body = reply.content
        cache.set(key, body, timeout=60 * 60)  # Cache for 1 hour
    return body