# What this PR does
Minor inbound email improvements:
* Adds SNS certificate caching (the [original JS
SDK](https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js#L101-L104)
does that as well)
* Makes sure we see a 500 when OnCall can't fetch the certificate
## Which issue(s) this PR closes
Related to https://github.com/grafana/oncall-private/issues/2905
<!--
*Note*: If you want the issue to be auto-closed once the PR is merged,
change "Related to" to "Closes" in the line above.
If you have more than one GitHub issue that this PR closes, be sure to
preface
each issue link with a [closing
keyword](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue).
This ensures that the issue(s) are auto-closed once the PR has been
merged.
-->
## Checklist
- [x] Unit, integration, and e2e (if applicable) tests updated
- [x] Documentation added (or `pr:no public docs` PR label added if not
required)
- [x] Added the relevant release notes label (see labels prefixed w/
`release:`). These labels dictate how your PR will
show up in the autogenerated release notes.
108 lines
3.9 KiB
Python
108 lines
3.9 KiB
Python
import logging
|
|
import re
|
|
from base64 import b64decode
|
|
from urllib.parse import urlparse
|
|
|
|
import requests
|
|
from cryptography.exceptions import InvalidSignature
|
|
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
|
from cryptography.hazmat.primitives.hashes import SHA1, SHA256
|
|
from cryptography.x509 import NameOID, load_pem_x509_certificate
|
|
from django.conf import settings
|
|
from django.core.cache import cache
|
|
|
|
logger = logging.getLogger(__name__)

# Hostnames a signing certificate may be served from: sns.<region>.amazonaws.com,
# optionally followed by the .cn suffix.
HOST_PATTERN = re.compile(r"^sns\.[a-zA-Z0-9\-]{3,}\.amazonaws\.com(\.cn)?$")

# Keys that must be present in a message before any further validation is attempted.
REQUIRED_KEYS = (
    "Message",
    "MessageId",
    "Timestamp",
    "TopicArn",
    "Type",
    "Signature",
    "SigningCertURL",
    "SignatureVersion",
)

# Keys (in order) that make up the canonical string to sign, per message type.
# Keys that are absent from a given message are skipped when the string is built.
SIGNING_KEYS_NOTIFICATION = ("Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type")
SIGNING_KEYS_SUBSCRIPTION = ("Message", "MessageId", "SubscribeURL", "Timestamp", "Token", "TopicArn", "Type")
|
|
|
|
|
|
def validate_amazon_sns_message(message: dict) -> bool:
    """
    Validate an AWS SNS message. Based on:
    - https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html
    - https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js
    - https://github.com/aws/aws-php-sns-message-validator/blob/3cee0fc1aee5538e1bd677654b09fad811061d0b/src/MessageValidator.php

    The checks run in this order, and the first failing one logs a warning
    and returns False:
    1. all REQUIRED_KEYS are present;
    2. TopicArn equals settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN;
    3. Type is a known message type;
    4. SigningCertURL is https, its host matches HOST_PATTERN and its path ends in ".pem";
    5. the certificate issuer's organization is "Amazon";
    6. Signature is decodable and SignatureVersion is "1" or "2";
    7. the signature verifies against the canonical message.

    Args:
        message: The parsed SNS message body.

    Returns:
        True if every check passes, False otherwise.

    Raises:
        Whatever fetch_certificate() or load_pem_x509_certificate() raise. These
        are not caught here, so a certificate that cannot be fetched or parsed
        surfaces as an error instead of being reported as an invalid message.
    """

    # Check if the message has all the required keys
    if not all(key in message for key in REQUIRED_KEYS):
        logger.warning("Missing required keys in the message, got: %s", message.keys())
        return False

    # Check TopicArn
    if message["TopicArn"] != settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN:
        logger.warning("Invalid TopicArn: %s", message["TopicArn"])
        return False

    # Construct the canonical message
    if message["Type"] == "Notification":
        signing_keys = SIGNING_KEYS_NOTIFICATION
    elif message["Type"] in ("SubscriptionConfirmation", "UnsubscribeConfirmation"):
        signing_keys = SIGNING_KEYS_SUBSCRIPTION
    else:
        logger.warning("Invalid message type: %s", message["Type"])
        return False
    # Signing keys that are absent from the message are skipped rather than signed as empty.
    canonical_message = "".join(f"{key}\n{message[key]}\n" for key in signing_keys if key in message).encode()

    # Check if SigningCertURL is a valid SNS URL
    signing_cert_url = message["SigningCertURL"]
    parsed_url = urlparse(signing_cert_url)
    if (
        parsed_url.scheme != "https"
        or not HOST_PATTERN.match(parsed_url.netloc)
        or not parsed_url.path.endswith(".pem")
    ):
        logger.warning("Invalid SigningCertURL: %s", signing_cert_url)
        return False

    # Fetch the certificate (errors propagate to the caller on purpose, see "Raises")
    certificate_bytes = fetch_certificate(signing_cert_url)

    # Verify the certificate issuer
    certificate = load_pem_x509_certificate(certificate_bytes)
    issuer_organizations = certificate.issuer.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)
    # An issuer without any organization attribute is treated as invalid instead of raising IndexError.
    if not issuer_organizations or issuer_organizations[0].value != "Amazon":
        logger.warning("Invalid certificate issuer: %s", certificate.issuer)
        return False

    # Verify the signature
    try:
        signature = b64decode(message["Signature"])
    except (TypeError, ValueError):
        # binascii.Error (e.g. incorrect padding) is a ValueError subclass; TypeError covers
        # a Signature that is not str/bytes. Either way the message cannot be authentic.
        logger.warning("Invalid signature encoding")
        return False
    if message["SignatureVersion"] == "1":
        hash_algorithm = SHA1()
    elif message["SignatureVersion"] == "2":
        hash_algorithm = SHA256()
    else:
        logger.warning("Invalid SignatureVersion: %s", message["SignatureVersion"])
        return False
    try:
        certificate.public_key().verify(signature, canonical_message, PKCS1v15(), hash_algorithm)
    except InvalidSignature:
        logger.warning("Invalid signature")
        return False

    return True
|
|
|
|
|
|
def fetch_certificate(certificate_url: str) -> bytes:
    """
    Return the raw bytes served at ``certificate_url``, going through the Django cache.

    A truthy cached value under the key ``aws_sns_cert_<url>`` is returned as is.
    Otherwise the URL is fetched over HTTP with a 5 second timeout and the
    response body is cached for one hour before being returned.

    Args:
        certificate_url: URL of the certificate. It is fetched as given; this
            function does not validate it.

    Returns:
        The response body (or the previously cached body) as bytes.

    Raises:
        requests.RequestException: if the request fails or times out, or if
            raise_for_status() rejects the response status.
    """
    cache_key = f"aws_sns_cert_{certificate_url}"
    cached_certificate = cache.get(cache_key)
    if cached_certificate:
        return cached_certificate

    response = requests.get(certificate_url, timeout=5)
    # Fail loudly on an error status so an error page is never cached as a certificate.
    response.raise_for_status()
    certificate = response.content

    cache.set(cache_key, certificate, timeout=60 * 60)  # Cache for 1 hour
    return certificate
|