oncall-engine/engine/apps/email/validate_amazon_sns_message.py
Vadim Stepanov 469dccc63b
Inbound email improvements (continued) (#5294)
# What this PR does

Minor inbound email improvements:

* Adds SNS certificate caching (the [original JS
SDK](https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js#L101-L104)
does that as well)
* Makes sure we see a 500 when OnCall can't fetch the certificate

## Which issue(s) this PR closes

Related to https://github.com/grafana/oncall-private/issues/2905

<!--
*Note*: If you want the issue to be auto-closed once the PR is merged,
change "Related to" to "Closes" in the line above.
If you have more than one GitHub issue that this PR closes, be sure to
preface
each issue link with a [closing
keyword](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue).
This ensures that the issue(s) are auto-closed once the PR has been
merged.
-->

## Checklist

- [x] Unit, integration, and e2e (if applicable) tests updated
- [x] Documentation added (or `pr:no public docs` PR label added if not
required)
- [x] Added the relevant release notes label (see labels prefixed w/
`release:`). These labels dictate how your PR will
    show up in the autogenerated release notes.
2024-11-25 09:32:53 +00:00

108 lines
3.9 KiB
Python

import logging
import re
from base64 import b64decode
from urllib.parse import urlparse
import requests
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.hashes import SHA1, SHA256
from cryptography.x509 import NameOID, load_pem_x509_certificate
from django.conf import settings
from django.core.cache import cache
logger = logging.getLogger(__name__)

# Host names accepted for SigningCertURL: "sns.", a label of at least three
# letters/digits/hyphens, then "amazonaws.com" with an optional ".cn" suffix.
HOST_PATTERN = re.compile(r"^sns\.[a-zA-Z0-9\-]{3,}\.amazonaws\.com(\.cn)?$")

# Keys that must be present in a message before any other check is attempted.
REQUIRED_KEYS = (
    "Message",
    "MessageId",
    "Timestamp",
    "TopicArn",
    "Type",
    "Signature",
    "SigningCertURL",
    "SignatureVersion",
)

# Keys (in signing order) that make up the canonical message for "Notification" messages.
SIGNING_KEYS_NOTIFICATION = (
    "Message",
    "MessageId",
    "Subject",
    "Timestamp",
    "TopicArn",
    "Type",
)

# Keys (in signing order) that make up the canonical message for
# "SubscriptionConfirmation" and "UnsubscribeConfirmation" messages.
SIGNING_KEYS_SUBSCRIPTION = (
    "Message",
    "MessageId",
    "SubscribeURL",
    "Timestamp",
    "Token",
    "TopicArn",
    "Type",
)
def validate_amazon_sns_message(message: dict) -> bool:
    """
    Validate an AWS SNS message. Based on:
    - https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html
    - https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js
    - https://github.com/aws/aws-php-sns-message-validator/blob/3cee0fc1aee5538e1bd677654b09fad811061d0b/src/MessageValidator.php

    Returns True only when every check passes: required keys are present, TopicArn matches
    settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN, the message type is known, SigningCertURL is an
    HTTPS ".pem" URL on a host matching HOST_PATTERN, SignatureVersion is "1" or "2", the
    certificate issuer organization is "Amazon", and the signature verifies against the
    canonical message. Any failed check logs a warning and returns False.

    Errors raised while fetching or parsing the certificate are not caught here and
    propagate to the caller.
    """
    # Check if the message has all the required keys
    if not all(key in message for key in REQUIRED_KEYS):
        logger.warning("Missing required keys in the message, got: %s", message.keys())
        return False

    # Check TopicArn
    if message["TopicArn"] != settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN:
        logger.warning("Invalid TopicArn: %s", message["TopicArn"])
        return False

    # Construct the canonical message
    if message["Type"] == "Notification":
        signing_keys = SIGNING_KEYS_NOTIFICATION
    elif message["Type"] in ("SubscriptionConfirmation", "UnsubscribeConfirmation"):
        signing_keys = SIGNING_KEYS_SUBSCRIPTION
    else:
        logger.warning("Invalid message type: %s", message["Type"])
        return False
    # Signing keys absent from the message (e.g. an optional "Subject") are skipped
    canonical_message = "".join(f"{key}\n{message[key]}\n" for key in signing_keys if key in message).encode()

    # Check if SigningCertURL is a valid SNS URL
    signing_cert_url = message["SigningCertURL"]
    parsed_url = urlparse(signing_cert_url)
    if (
        parsed_url.scheme != "https"
        or not HOST_PATTERN.match(parsed_url.netloc)
        or not parsed_url.path.endswith(".pem")
    ):
        logger.warning("Invalid SigningCertURL: %s", signing_cert_url)
        return False

    # Pick the hash algorithm. This (and decoding the signature below) happens before the
    # certificate is fetched so that messages which would be rejected anyway never trigger
    # an outbound network request.
    if message["SignatureVersion"] == "1":
        hash_algorithm = SHA1()
    elif message["SignatureVersion"] == "2":
        hash_algorithm = SHA256()
    else:
        logger.warning("Invalid SignatureVersion: %s", message["SignatureVersion"])
        return False

    # Decode the signature. The value comes from the (untrusted) message: b64decode raises
    # binascii.Error (a ValueError subclass) on malformed base64 and TypeError on non-string
    # input; treat both as an invalid message rather than letting them crash the caller.
    try:
        signature = b64decode(message["Signature"])
    except (ValueError, TypeError):
        logger.warning("Invalid signature encoding")
        return False

    # Fetch the certificate
    certificate_bytes = fetch_certificate(signing_cert_url)

    # Verify the certificate issuer
    certificate = load_pem_x509_certificate(certificate_bytes)
    issuer_organizations = certificate.issuer.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)
    # An issuer without any organization attribute is rejected instead of raising IndexError
    if not issuer_organizations or issuer_organizations[0].value != "Amazon":
        logger.warning("Invalid certificate issuer: %s", certificate.issuer)
        return False

    # Verify the signature
    try:
        certificate.public_key().verify(signature, canonical_message, PKCS1v15(), hash_algorithm)
    except InvalidSignature:
        logger.warning("Invalid signature")
        return False

    return True
def fetch_certificate(certificate_url: str) -> bytes:
    """
    Return the PEM certificate bytes served at ``certificate_url``, using the Django cache.

    A cached copy is returned when one exists. Otherwise the certificate is downloaded
    (5 second timeout), checked to be a parsable PEM X.509 certificate, cached for one hour
    and returned.

    Raises:
        requests.RequestException: if the download fails or the response status is an error.
        ValueError: if the downloaded content cannot be parsed as a PEM X.509 certificate.
    """
    cache_key = f"aws_sns_cert_{certificate_url}"
    cached_certificate = cache.get(cache_key)
    if cached_certificate:
        return cached_certificate

    response = requests.get(certificate_url, timeout=5)
    response.raise_for_status()
    certificate = response.content

    # Parse the content before caching it: a successful response with a bad body (e.g. an
    # error or proxy page) must not be stored, otherwise every message using this URL would
    # keep failing until the cache entry expires. The parsed object is discarded on purpose;
    # callers receive the raw bytes as before.
    load_pem_x509_certificate(certificate)

    cache.set(cache_key, certificate, timeout=60 * 60)  # Cache for 1 hour
    return certificate