diff --git a/fasthtml/core.py b/fasthtml/core.py index 3b3ecfa3..34981702 100644 --- a/fasthtml/core.py +++ b/fasthtml/core.py @@ -154,16 +154,16 @@ def form2dict(form: FormData) -> dict: # %% ../nbs/api/00_core.ipynb #42c9cea0 async def parse_form(req: Request) -> FormData: - "Starlette errors on empty multipart forms, so this checks for that situation" + "Starlette errors on empty multipart/json forms, so this checks for that situation" ctype = req.headers.get("Content-Type", "") if ctype.startswith("multipart/form-data"): try: boundary = ctype.split("boundary=")[1].strip() except IndexError: raise HTTPException(400, "Invalid form-data: no boundary") if int(req.headers.get("Content-Length", "0")) <= len(boundary) + 6: return FormData() return await req.form() - await req.body() # Cache body for non-multipart request types - return await req.json() if ctype == 'application/json' else await req.form() - + body = await req.body() # Cache body for non-multipart request types + if ctype == 'application/json': return await req.json() if body else {} + return await req.form() # %% ../nbs/api/00_core.ipynb #0caedd04 async def _from_body(conn, p, data): diff --git a/nbs/api/00_core.ipynb b/nbs/api/00_core.ipynb index 8cddd761..94fd26a1 100644 --- a/nbs/api/00_core.ipynb +++ b/nbs/api/00_core.ipynb @@ -595,15 +595,16 @@ "source": [ "#| export\n", "async def parse_form(req: Request) -> FormData:\n", - " \"Starlette errors on empty multipart forms, so this checks for that situation\"\n", + " \"Starlette errors on empty multipart/json forms, so this checks for that situation\"\n", " ctype = req.headers.get(\"Content-Type\", \"\")\n", " if ctype.startswith(\"multipart/form-data\"):\n", " try: boundary = ctype.split(\"boundary=\")[1].strip()\n", " except IndexError: raise HTTPException(400, \"Invalid form-data: no boundary\")\n", " if int(req.headers.get(\"Content-Length\", \"0\")) <= len(boundary) + 6: return FormData()\n", " return await req.form()\n", - " await req.body() # Cache body for non-multipart request types\n", - " return await req.json() if ctype == 'application/json' else await req.form()\n" + " body = await req.body() # Cache body for non-multipart request types\n", + " if ctype == 'application/json': return await req.json() if body else {}\n", + " return await req.form()" ] }, { @@ -734,6 +735,23 @@ "print(response.json())" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "301e06eb", + "metadata": {}, + "outputs": [], + "source": [ + "# An empty `application/json` body yields `{}` (like the empty-multipart guard), not a 500\n", + "async def _empty_json(req): return JSONResponse(form2dict(await parse_form(req)))\n", + "client = TestClient(Starlette(routes=[Route('/', _empty_json, methods=['POST'])]))\n", + "res = client.post('/', headers={'content-type':'application/json'})\n", + "test_eq(res.status_code, 200)\n", + "test_eq(res.json(), {})\n", + "res = client.post('/', json={'a':1}) # non-empty json still parses (behavior preserved)\n", + "test_eq(res.json(), {'a':1})" + ] + }, { "cell_type": "code", "execution_count": null,