diff --git a/task-sdk/src/airflow/sdk/execution_time/sentry/configured.py b/task-sdk/src/airflow/sdk/execution_time/sentry/configured.py index d833964621662..1586c146c536b 100644 --- a/task-sdk/src/airflow/sdk/execution_time/sentry/configured.py +++ b/task-sdk/src/airflow/sdk/execution_time/sentry/configured.py @@ -144,10 +144,14 @@ def wrapped_run(ti: RuntimeTaskInstance, context: Context, log: Logger) -> RunRe try: self.add_tagging(context["dag_run"], ti) self.add_breadcrumbs(ti) - return run(ti, context, log) + run_return = run(ti, context, log) except Exception as e: sentry_sdk.capture_exception(e) raise + _, _, run_error = run_return + if run_error: + sentry_sdk.capture_exception(run_error) + return run_return return wrapped_run diff --git a/task-sdk/tests/task_sdk/execution_time/test_sentry.py b/task-sdk/tests/task_sdk/execution_time/test_sentry.py index c7f2daccdc417..81ceab696749d 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_sentry.py +++ b/task-sdk/tests/task_sdk/execution_time/test_sentry.py @@ -21,6 +21,7 @@ import importlib import sys import types +from typing import TYPE_CHECKING from unittest import mock import pytest @@ -35,6 +36,11 @@ from tests_common.test_utils.config import conf_vars +if TYPE_CHECKING: + from structlog.typing import FilteringBoundLogger as Logger + + from airflow.sdk import Context + LOGICAL_DATE = timezone.utcnow() SCHEDULE_INTERVAL = datetime.timedelta(days=1) DATA_INTERVAL = (LOGICAL_DATE, LOGICAL_DATE + SCHEDULE_INTERVAL) @@ -121,8 +127,10 @@ def mock_sentry_sdk(self): sentry_sdk = types.ModuleType("sentry_sdk") sentry_sdk.init = mock.MagicMock() sentry_sdk.integrations = mock.Mock(logging=sentry_sdk_integrations_logging) + sentry_sdk.new_scope = mock.MagicMock() sentry_sdk.configure_scope = mock.MagicMock() sentry_sdk.add_breadcrumb = mock.MagicMock() + sentry_sdk.capture_exception = mock.MagicMock() sys.modules["sentry_sdk"] = sentry_sdk sys.modules["sentry_sdk.integrations.logging"] = sentry_sdk_integrations_logging @@ -135,8 +143,10 @@ def remove_mock_sentry_sdk(self, mock_sentry_sdk): yield mock_sentry_sdk.integrations.logging.ignore_logger.reset_mock() mock_sentry_sdk.init.reset_mock() + mock_sentry_sdk.new_scope.reset_mock() mock_sentry_sdk.configure_scope.reset_mock() mock_sentry_sdk.add_breadcrumb.reset_mock() + mock_sentry_sdk.capture_exception.reset_mock() @pytest.fixture def sentry(self, mock_sentry_sdk): @@ -270,3 +280,48 @@ def test_minimum_config(self, mock_sentry_sdk, sentry_minimum): sentry_minimum.prepare_to_enrich_errors(executor_integration="") assert mock_sentry_sdk.integrations.logging.ignore_logger.mock_calls == [mock.call("airflow.task")] assert mock_sentry_sdk.init.mock_calls == [mock.call(integrations=[])] + + @pytest.mark.parametrize( + ("run_exception_return", "run_raise"), + ( + pytest.param(ValueError("This is Run Exception"), False, id="run_with_raise_exception"), + pytest.param(None, True, id="run_with_return_exception"), + pytest.param(None, False, id="run_without_exception"), + ), + ) + def test_sentry_capture_exception( + self, + mock_supervisor_comms, + sentry, + mock_sentry_sdk, + dag_run, + task_instance, + run_exception_return, + run_raise, + ): + """ + Test that sentry_sdk.capture_exception is called on error + """ + mock_supervisor_comms.send.return_value = TaskBreadcrumbsResult.model_construct( + breadcrumbs=[TASK_DATA], + ) + log = mock.Mock() + + class TestException(Exception): ... + + @sentry.enrich_errors + def mocked_run(ti: RuntimeTaskInstance, context: Context, log: Logger): + if run_raise: + raise TestException("This is Run Exception") + return STATE, None, run_exception_return + + if run_raise: + with pytest.raises(TestException): + mocked_run(task_instance, {"dag_run": dag_run}, log) + else: + mocked_run(task_instance, {"dag_run": dag_run}, log) + + if run_exception_return is not None or run_raise: + mock_sentry_sdk.capture_exception.assert_called() + else: + mock_sentry_sdk.capture_exception.assert_not_called()