diff --git a/deploy/docker/job.py b/deploy/docker/job.py index 51f49a593..5183b8084 100644 --- a/deploy/docker/job.py +++ b/deploy/docker/job.py @@ -5,6 +5,7 @@ from typing import Dict, Optional, Callable from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request +from auth import verify_token from pydantic import BaseModel, HttpUrl from api import ( @@ -30,6 +31,26 @@ def init_job_router(redis, config, token_dep) -> APIRouter: _redis, _config, _token_dep = redis, config, token_dep return router +# Auth dependency for the job endpoints. The original code used +# `Depends(lambda: _token_dep())`, which prevented FastAPI from resolving the +# Bearer credentials and raised AttributeError (HTTP 500) whenever jwt_enabled +# was true. This wrapper mirrors auth.get_token_dependency: enforce a valid +# Bearer token when JWT is on, and no-op when it is off. +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +_job_bearer = HTTPBearer(auto_error=False) + + +def _job_token_dep(credentials: Optional[HTTPAuthorizationCredentials] = Depends(_job_bearer)): + if not (_config or {}).get("security", {}).get("jwt_enabled", False): + return None + if credentials is None: + raise HTTPException( + status_code=401, + detail="Authentication required. Please provide a valid Bearer token.", + headers={"WWW-Authenticate": "Bearer"}, + ) + return verify_token(credentials) + # ---------- payload models -------------------------------------------------- class LlmJobPayload(BaseModel): @@ -56,7 +77,7 @@ async def llm_job_enqueue( payload: LlmJobPayload, background_tasks: BackgroundTasks, request: Request, - _td: Dict = Depends(lambda: _token_dep()), # late-bound dep + _td: Dict = Depends(_job_token_dep), ): webhook_config = None if payload.webhook_config: @@ -87,7 +108,7 @@ async def llm_job_enqueue( async def llm_job_status( request: Request, task_id: str, - _td: Dict = Depends(lambda: _token_dep()) + _td: Dict = Depends(_job_token_dep) ): return await handle_task_status(_redis, task_id, base_url=str(request.base_url)) @@ -97,7 +118,7 @@ async def llm_job_status( async def crawl_job_enqueue( payload: CrawlJobPayload, background_tasks: BackgroundTasks, - _td: Dict = Depends(lambda: _token_dep()), + _td: Dict = Depends(_job_token_dep), ): webhook_config = None if payload.webhook_config: @@ -123,6 +144,6 @@ async def crawl_job_enqueue( async def crawl_job_status( request: Request, task_id: str, - _td: Dict = Depends(lambda: _token_dep()) + _td: Dict = Depends(_job_token_dep) ): return await handle_task_status(_redis, task_id, base_url=str(request.base_url))