diff --git a/engine/apps/webhooks/models/webhook.py b/engine/apps/webhooks/models/webhook.py index 3128a39a..15cbd4bf 100644 --- a/engine/apps/webhooks/models/webhook.py +++ b/engine/apps/webhooks/models/webhook.py @@ -53,6 +53,12 @@ def generate_public_primary_key_for_webhook(): return new_public_primary_key +class WebhookSession(requests.Session): + def send(self, request, **kwargs): + parse_url(request.url) # validate URL on every redirect + return super().send(request, **kwargs) + + class WebhookQueryset(models.QuerySet): def delete(self): self.update(deleted_at=timezone.now(), name=F("name") + "_deleted_" + F("public_primary_key")) @@ -276,21 +282,15 @@ class Webhook(models.Model): raise InvalidWebhookTrigger(e.fallback_message) def make_request(self, url, request_kwargs): - if self.http_method == "GET": - r = requests.get(url, timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, **request_kwargs) - elif self.http_method == "POST": - r = requests.post(url, timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, **request_kwargs) - elif self.http_method == "PUT": - r = requests.put(url, timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, **request_kwargs) - elif self.http_method == "DELETE": - r = requests.delete(url, timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, **request_kwargs) - elif self.http_method == "OPTIONS": - r = requests.options(url, timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, **request_kwargs) - elif self.http_method == "PATCH": - r = requests.patch(url, timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, **request_kwargs) - else: + if self.http_method not in ("GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"): raise ValueError(f"Unsupported http method: {self.http_method}") - return r + + with WebhookSession() as session: + response = session.request( + self.http_method, url, timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, **request_kwargs + ) + + return response # Insight logs @property diff --git a/engine/apps/webhooks/tests/test_trigger_webhook.py b/engine/apps/webhooks/tests/test_trigger_webhook.py index 3ef949a9..381b79eb 100644 --- a/engine/apps/webhooks/tests/test_trigger_webhook.py +++ b/engine/apps/webhooks/tests/test_trigger_webhook.py @@ -11,6 +11,7 @@ from apps.alerts.models import AlertGroupExternalID, AlertGroupLogRecord, Escala from apps.base.models import UserNotificationPolicyLogRecord from apps.public_api.serializers import IncidentSerializer from apps.webhooks.models import Webhook +from apps.webhooks.models.webhook import WebhookSession from apps.webhooks.tasks import execute_webhook, send_webhook_event from apps.webhooks.tasks.trigger_webhook import NOT_FROM_SELECTED_INTEGRATION from settings.base import WEBHOOK_RESPONSE_LIMIT @@ -141,10 +142,10 @@ def test_execute_webhook_integration_filter_not_matching( ) webhook.filtered_integrations.add(other_alert_receive_channel) - with patch("apps.webhooks.models.webhook.requests") as mock_requests: + with patch("apps.webhooks.models.webhook.WebhookSession.request") as mock_request: execute_webhook(webhook.pk, alert_group.pk, None, None) - assert not mock_requests.post.called + assert not mock_request.called # no response is created for the webhook assert webhook.responses.count() == 0 # check log should exist @@ -166,10 +167,10 @@ def test_execute_webhook_integration_filter_matching( ) webhook.filtered_integrations.add(alert_receive_channel) - with patch("apps.webhooks.models.webhook.requests") as mock_requests: + with patch("apps.webhooks.models.webhook.WebhookSession.request") as mock_request: execute_webhook(webhook.pk, alert_group.pk, None, None) - assert not mock_requests.post.called + assert not mock_request.called # no response is created for the webhook assert webhook.responses.count() == 0 # check log should exist @@ -235,10 +236,13 @@ def test_execute_webhook_ok( httpretty.register_uri(httpretty.POST, templated_url, responses=[mock_response]) with patch("apps.webhooks.utils.socket.gethostbyname", return_value="8.8.8.8"): - with patch("apps.webhooks.models.webhook.requests", wraps=requests) as mock_requests: + with patch( + "apps.webhooks.models.webhook.WebhookSession.request", wraps=WebhookSession().request + ) as mock_request: execute_webhook(webhook.pk, alert_group.pk, user.pk, None) - mock_requests.post.assert_called_once_with( + mock_request.assert_called_once_with( + "POST", templated_url, timeout=TIMEOUT, headers={"some-header": alert_group.public_primary_key}, @@ -310,11 +314,10 @@ def test_execute_webhook_via_escalation_ok( mock_response = MockResponse() with patch("apps.webhooks.utils.socket.gethostbyname") as mock_gethostbyname: mock_gethostbyname.return_value = "8.8.8.8" - with patch("apps.webhooks.models.webhook.requests") as mock_requests: - mock_requests.post.return_value = mock_response + with patch("apps.webhooks.models.webhook.WebhookSession.request", return_value=mock_response) as mock_request: execute_webhook(webhook.pk, alert_group.pk, user.pk, escalation_policy.pk) - assert mock_requests.post.called + assert mock_request.called # check log record log_record = alert_group.log_records.last() assert log_record.type == AlertGroupLogRecord.TYPE_CUSTOM_WEBHOOK_TRIGGERED @@ -377,11 +380,10 @@ def test_execute_webhook_ok_forward_all( mock_response = MockResponse() with patch("apps.webhooks.utils.socket.gethostbyname") as mock_gethostbyname: mock_gethostbyname.return_value = "8.8.8.8" - with patch("apps.webhooks.models.webhook.requests") as mock_requests: - mock_requests.post.return_value = mock_response + with patch("apps.webhooks.models.webhook.WebhookSession.request", return_value=mock_response) as mock_request: execute_webhook(webhook.pk, alert_group.pk, user.pk, None, trigger_type=Webhook.TRIGGER_ACKNOWLEDGE) - assert mock_requests.post.called + assert mock_request.called expected_data = { "event": { "type": "acknowledge", @@ -423,12 +425,13 @@ def test_execute_webhook_ok_forward_all( "alert_group_resolved_by": None, } expected_call = call( + "POST", "https://something/{}/".format(alert_group.public_primary_key), timeout=TIMEOUT, headers={}, json=expected_data, ) - assert mock_requests.post.call_args == expected_call + assert mock_request.call_args == expected_call # check logs log = webhook.responses.all()[0] assert log.trigger_type == Webhook.TRIGGER_ACKNOWLEDGE @@ -485,11 +488,10 @@ def test_execute_webhook_ok_forward_all_resolved( mock_response = MockResponse() with patch("apps.webhooks.utils.socket.gethostbyname") as mock_gethostbyname: mock_gethostbyname.return_value = "8.8.8.8" - with patch("apps.webhooks.models.webhook.requests") as mock_requests: - mock_requests.post.return_value = mock_response + with patch("apps.webhooks.models.webhook.WebhookSession.request", return_value=mock_response) as mock_request: execute_webhook(webhook.pk, alert_group.pk, user.pk, None, trigger_type=Webhook.TRIGGER_RESOLVE) - assert mock_requests.post.called + assert mock_request.called expected_data = { "event": { "type": "resolve", @@ -535,12 +537,13 @@ def test_execute_webhook_ok_forward_all_resolved( }, } expected_call = call( + "POST", "https://something/{}/".format(alert_group.public_primary_key), timeout=TIMEOUT, headers={}, json=expected_data, ) - assert mock_requests.post.call_args == expected_call + assert mock_request.call_args == expected_call # check logs log = webhook.responses.all()[0] assert log.trigger_type == Webhook.TRIGGER_RESOLVE @@ -610,19 +613,19 @@ def test_execute_webhook_using_responses_data( mock_response = MockResponse() with patch("apps.webhooks.utils.socket.gethostbyname") as mock_gethostbyname: mock_gethostbyname.return_value = "8.8.8.8" - with patch("apps.webhooks.models.webhook.requests") as mock_requests: - mock_requests.post.return_value = mock_response + with patch("apps.webhooks.models.webhook.WebhookSession.request", return_value=mock_response) as mock_request: execute_webhook(webhook.pk, alert_group.pk, user.pk, None) - assert mock_requests.post.called + assert mock_request.called expected_data = {"value": "updated"} expected_call = call( + "POST", "https://something/third-party-id/", timeout=TIMEOUT, headers={}, json=expected_data, ) - assert mock_requests.post.call_args == expected_call + assert mock_request.call_args == expected_call # check logs log = webhook.responses.all()[0] assert log.status_code == 200 @@ -646,10 +649,10 @@ def test_execute_webhook_trigger_false( trigger_template="{{ integration_id == 'the-integration' }}", ) - with patch("apps.webhooks.models.webhook.requests") as mock_requests: + with patch("apps.webhooks.models.webhook.WebhookSession.request") as mock_request: execute_webhook(webhook.pk, alert_group.pk, None, None) - assert not mock_requests.post.called + assert not mock_request.called # no response is created for the webhook assert webhook.responses.count() == 0 # check log should exist @@ -709,10 +712,10 @@ def test_execute_webhook_errors( with patch("apps.webhooks.utils.socket.gethostbyname") as mock_gethostbyname: # make it a valid URL when resolving name mock_gethostbyname.return_value = "8.8.8.8" - with patch("apps.webhooks.models.webhook.requests") as mock_requests: + with patch("apps.webhooks.models.webhook.WebhookSession.request") as mock_request: execute_webhook(webhook.pk, alert_group.pk, None, None) - assert not mock_requests.post.called + assert not mock_request.called log = webhook.responses.all()[0] assert log.status_code is None assert log.content is None @@ -755,17 +758,17 @@ def test_response_content_limit( mock_response = MockResponse(content="A" * content_length) with patch("apps.webhooks.utils.socket.gethostbyname") as mock_gethostbyname: mock_gethostbyname.return_value = "8.8.8.8" - with patch("apps.webhooks.models.webhook.requests") as mock_requests: - mock_requests.post.return_value = mock_response + with patch("apps.webhooks.models.webhook.WebhookSession.request", return_value=mock_response) as mock_request: execute_webhook(webhook.pk, alert_group.pk, user.pk, None) - assert mock_requests.post.called + assert mock_request.called expected_call = call( + "POST", "https://test/", timeout=TIMEOUT, headers={}, ) - assert mock_requests.post.call_args == expected_call + assert mock_request.call_args == expected_call # check logs log = webhook.responses.all()[0] assert log.status_code == 200 @@ -774,13 +777,13 @@ def test_response_content_limit( @patch("apps.webhooks.tasks.trigger_webhook.execute_webhook", wraps=execute_webhook) -@patch("apps.webhooks.models.webhook.requests") +@patch("apps.webhooks.models.webhook.WebhookSession.request") @patch("apps.webhooks.utils.socket.gethostbyname", return_value="8.8.8.8") @pytest.mark.django_db @pytest.mark.parametrize("exception", [requests.exceptions.ConnectTimeout, requests.exceptions.ReadTimeout]) def test_manually_retried_exceptions( _mock_gethostbyname, - mock_requests, + mock_request, spy_execute_webhook, make_organization, make_user_for_organization, @@ -789,7 +792,7 @@ def test_manually_retried_exceptions( make_custom_webhook, exception, ): - mock_requests.post.side_effect = exception("foo bar") + mock_request.side_effect = exception("foo bar") organization = make_organization() user = make_user_for_organization(organization) @@ -810,12 +813,12 @@ def test_manually_retried_exceptions( # should retry execute_webhook(*execute_webhook_args) - mock_requests.post.assert_called_once_with("https://test/", timeout=TIMEOUT, headers={}) + mock_request.assert_called_once_with("POST", "https://test/", timeout=TIMEOUT, headers={}) spy_execute_webhook.apply_async.assert_called_once_with( execute_webhook_args, kwargs={"trigger_type": None, "manual_retry_num": 1}, countdown=10 ) - mock_requests.reset_mock() + mock_request.reset_mock() spy_execute_webhook.reset_mock() # should stop retrying after 3 attempts without raising issue @@ -824,16 +827,16 @@ def test_manually_retried_exceptions( except Exception: pytest.fail() - mock_requests.post.assert_called_once_with("https://test/", timeout=TIMEOUT, headers={}) + mock_request.assert_called_once_with("POST", "https://test/", timeout=TIMEOUT, headers={}) spy_execute_webhook.apply_async.assert_not_called() -@patch("apps.webhooks.models.webhook.requests.post", return_value=MockResponse()) +@patch("apps.webhooks.models.webhook.WebhookSession.request", return_value=MockResponse()) @patch("apps.webhooks.utils.socket.gethostbyname", return_value="8.8.8.8") @pytest.mark.django_db def test_execute_webhook_integration_config( _, - mock_requests_post, + mock_request, make_organization, make_user_for_organization, make_alert_receive_channel, @@ -879,14 +882,15 @@ def test_execute_webhook_integration_config( ) as mock_on_webhook_response_created: execute_webhook(webhook.pk, alert_group.pk, user.pk, None, trigger_type=Webhook.TRIGGER_ALERT_GROUP_CREATED) - assert mock_requests_post.called + assert mock_request.called # check external ID - assert mock_requests_post.call_args[0][0] == "https://something/test123" - assert mock_requests_post.call_args[1]["json"]["external_id"] == "test123" + assert mock_request.call_args[0][0] == "POST" + assert mock_request.call_args[0][1] == "https://something/test123" + assert mock_request.call_args[1]["json"]["external_id"] == "test123" # check additional webhook data - assert mock_requests_post.call_args[1]["json"]["additional_field"] == "additional_value" + assert mock_request.call_args[1]["json"]["additional_field"] == "additional_value" mock_additional_webhook_data.assert_called_once_with(source_alert_receive_channel) # check on_webhook_response_created is called diff --git a/engine/apps/webhooks/tests/test_webhook.py b/engine/apps/webhooks/tests/test_webhook.py index 91822f37..7e86dc02 100644 --- a/engine/apps/webhooks/tests/test_webhook.py +++ b/engine/apps/webhooks/tests/test_webhook.py @@ -1,5 +1,6 @@ from unittest.mock import call, patch +import httpretty import pytest from django.conf import settings from requests.auth import HTTPBasicAuth @@ -225,13 +226,11 @@ def test_check_trigger_template_ok(make_organization, make_custom_webhook): def test_make_request(make_organization, make_custom_webhook): organization = make_organization() - with patch("apps.webhooks.models.webhook.requests") as mock_requests: + with patch("apps.webhooks.models.webhook.WebhookSession.request") as mock_request: for method in ("GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"): webhook = make_custom_webhook(organization=organization, http_method=method) webhook.make_request("url", {"foo": "bar"}) - expected_call = getattr(mock_requests, method.lower()) - assert expected_call.called - assert expected_call.call_args == call("url", timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, foo="bar") + assert mock_request.call_args == call(method, "url", timeout=settings.OUTGOING_WEBHOOK_TIMEOUT, foo="bar") # invalid webhook = make_custom_webhook(organization=organization, http_method="NOT") @@ -239,6 +238,20 @@ def test_make_request(make_organization, make_custom_webhook): webhook.make_request("url", {"foo": "bar"}) +@httpretty.activate(verbose=True, allow_net_connect=False) +@pytest.mark.django_db +def test_make_request_bad_redirect(make_organization, make_custom_webhook): + organization = make_organization() + webhook = make_custom_webhook(organization=organization, http_method="POST") + + url = "http://example.com" + response = httpretty.Response(body="Redirect", status=302, location="127.0.0.1") + httpretty.register_uri(httpretty.POST, url, responses=[response]) + + with pytest.raises(InvalidWebhookUrl): + webhook.make_request(url, {}) + + @pytest.mark.django_db def test_escaping_payload_with_double_quotes(make_organization, make_custom_webhook): organization = make_organization()