diff --git a/api/app/settings/common.py b/api/app/settings/common.py index cd660dd4b3d7..b165c5745c28 100644 --- a/api/app/settings/common.py +++ b/api/app/settings/common.py @@ -333,6 +333,7 @@ "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", } MIDDLEWARE = [ + "core.middleware.query_params.NullCharacterQueryParamMiddleware", "common.core.middleware.APIResponseVersionHeaderMiddleware", "common.gunicorn.middleware.RouteLoggerMiddleware", "django.middleware.security.SecurityMiddleware", diff --git a/api/core/middleware/query_params.py b/api/core/middleware/query_params.py new file mode 100644 index 000000000000..64066c8a59c1 --- /dev/null +++ b/api/core/middleware/query_params.py @@ -0,0 +1,29 @@ +from django.http import QueryDict + + +class NullCharacterQueryParamMiddleware: + """ + Strip NUL (0x00) characters from query parameter values. + + Prevents ValueError exceptions when query parameter values containing + null characters are passed to database queries. + """ + + def __init__(self, get_response): # type: ignore[no-untyped-def] + self.get_response = get_response + + def __call__(self, request): # type: ignore[no-untyped-def] + if "\x00" in request.META.get("QUERY_STRING", ""): + sanitized = QueryDict(mutable=True) + for key, values in request.GET.lists(): + sanitized_key = key.replace("\x00", "") + sanitized.setlist( + sanitized_key, + [v.replace("\x00", "") for v in values], + ) + request.GET = sanitized + request.META["QUERY_STRING"] = request.META["QUERY_STRING"].replace( + "\x00", "" + ) + + return self.get_response(request) diff --git a/api/tests/unit/core/middleware/test_unit_core_middleware_query_params.py b/api/tests/unit/core/middleware/test_unit_core_middleware_query_params.py new file mode 100644 index 000000000000..ec751a8a7957 --- /dev/null +++ b/api/tests/unit/core/middleware/test_unit_core_middleware_query_params.py @@ -0,0 +1,73 @@ +from django.http import HttpResponse, QueryDict + +from core.middleware.query_params import NullCharacterQueryParamMiddleware + + +def test_null_char_middleware__strips_null_from_query_param_values(mocker): # type: ignore[no-untyped-def] + # Given + mocked_get_response = mocker.MagicMock(return_value=HttpResponse()) + mock_request = mocker.MagicMock() + mock_request.META = {"QUERY_STRING": "identifier=test\x00value"} + query_dict = QueryDict("identifier=test\x00value") + mock_request.GET = query_dict + + middleware = NullCharacterQueryParamMiddleware(mocked_get_response) # type: ignore[no-untyped-call] + + # When + middleware(mock_request) + + # Then + assert mock_request.GET["identifier"] == "testvalue" + assert "\x00" not in mock_request.META["QUERY_STRING"] + + +def test_null_char_middleware__strips_null_from_query_param_keys(mocker): # type: ignore[no-untyped-def] + # Given + mocked_get_response = mocker.MagicMock(return_value=HttpResponse()) + mock_request = mocker.MagicMock() + mock_request.META = {"QUERY_STRING": "ident\x00ifier=test"} + query_dict = QueryDict("ident\x00ifier=test") + mock_request.GET = query_dict + + middleware = NullCharacterQueryParamMiddleware(mocked_get_response) # type: ignore[no-untyped-call] + + # When + middleware(mock_request) + + # Then + assert "identifier" in mock_request.GET + assert mock_request.GET["identifier"] == "test" + + +def test_null_char_middleware__no_null_chars__passes_through(mocker): # type: ignore[no-untyped-def] + # Given + mocked_get_response = mocker.MagicMock(return_value=HttpResponse()) + mock_request = mocker.MagicMock() + mock_request.META = {"QUERY_STRING": "identifier=testvalue"} + original_get = QueryDict("identifier=testvalue") + mock_request.GET = original_get + + middleware = NullCharacterQueryParamMiddleware(mocked_get_response) # type: ignore[no-untyped-call] + + # When + middleware(mock_request) + + # Then + assert mock_request.GET is original_get + + +def test_null_char_middleware__empty_query_string__passes_through(mocker): # type: ignore[no-untyped-def] + # Given + mocked_get_response = mocker.MagicMock(return_value=HttpResponse()) + mock_request = mocker.MagicMock() + mock_request.META = {"QUERY_STRING": ""} + original_get = QueryDict("") + mock_request.GET = original_get + + middleware = NullCharacterQueryParamMiddleware(mocked_get_response) # type: ignore[no-untyped-call] + + # When + middleware(mock_request) + + # Then + assert mock_request.GET is original_get