Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
}
MIDDLEWARE = [
"core.middleware.query_params.NullCharacterQueryParamMiddleware",
"common.core.middleware.APIResponseVersionHeaderMiddleware",
"common.gunicorn.middleware.RouteLoggerMiddleware",
"django.middleware.security.SecurityMiddleware",
Expand Down
29 changes: 29 additions & 0 deletions api/core/middleware/query_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from django.http import QueryDict


class NullCharacterQueryParamMiddleware:
"""
Strip NUL (0x00) characters from query parameter values.

Prevents ValueError exceptions when query parameter values containing
null characters are passed to database queries.
"""

def __init__(self, get_response): # type: ignore[no-untyped-def]
self.get_response = get_response

def __call__(self, request): # type: ignore[no-untyped-def]
if "\x00" in request.META.get("QUERY_STRING", ""):
sanitized = QueryDict(mutable=True)
for key, values in request.GET.lists():
sanitized_key = key.replace("\x00", "")
sanitized.setlist(
sanitized_key,
[v.replace("\x00", "") for v in values],
)
request.GET = sanitized
request.META["QUERY_STRING"] = request.META["QUERY_STRING"].replace(
"\x00", ""
)

return self.get_response(request)
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from django.http import HttpResponse, QueryDict

from core.middleware.query_params import NullCharacterQueryParamMiddleware


def test_null_char_middleware__strips_null_from_query_param_values(mocker): # type: ignore[no-untyped-def]
# Given
mocked_get_response = mocker.MagicMock(return_value=HttpResponse())
mock_request = mocker.MagicMock()
mock_request.META = {"QUERY_STRING": "identifier=test\x00value"}
query_dict = QueryDict("identifier=test\x00value")
mock_request.GET = query_dict

middleware = NullCharacterQueryParamMiddleware(mocked_get_response) # type: ignore[no-untyped-call]

# When
middleware(mock_request)

# Then
assert mock_request.GET["identifier"] == "testvalue"
assert "\x00" not in mock_request.META["QUERY_STRING"]


def test_null_char_middleware__strips_null_from_query_param_keys(mocker): # type: ignore[no-untyped-def]
# Given
mocked_get_response = mocker.MagicMock(return_value=HttpResponse())
mock_request = mocker.MagicMock()
mock_request.META = {"QUERY_STRING": "ident\x00ifier=test"}
query_dict = QueryDict("ident\x00ifier=test")
mock_request.GET = query_dict

middleware = NullCharacterQueryParamMiddleware(mocked_get_response) # type: ignore[no-untyped-call]

# When
middleware(mock_request)

# Then
assert "identifier" in mock_request.GET
assert mock_request.GET["identifier"] == "test"


def test_null_char_middleware__no_null_chars__passes_through(mocker): # type: ignore[no-untyped-def]
# Given
mocked_get_response = mocker.MagicMock(return_value=HttpResponse())
mock_request = mocker.MagicMock()
mock_request.META = {"QUERY_STRING": "identifier=testvalue"}
original_get = QueryDict("identifier=testvalue")
mock_request.GET = original_get

middleware = NullCharacterQueryParamMiddleware(mocked_get_response) # type: ignore[no-untyped-call]

# When
middleware(mock_request)

# Then
assert mock_request.GET is original_get


def test_null_char_middleware__empty_query_string__passes_through(mocker): # type: ignore[no-untyped-def]
# Given
mocked_get_response = mocker.MagicMock(return_value=HttpResponse())
mock_request = mocker.MagicMock()
mock_request.META = {"QUERY_STRING": ""}
original_get = QueryDict("")
mock_request.GET = original_get

middleware = NullCharacterQueryParamMiddleware(mocked_get_response) # type: ignore[no-untyped-call]

# When
middleware(mock_request)

# Then
assert mock_request.GET is original_get
Loading