Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions deploy/docker/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import Dict, Optional, Callable
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from auth import verify_token
from pydantic import BaseModel, HttpUrl

from api import (
Expand All @@ -30,6 +31,26 @@ def init_job_router(redis, config, token_dep) -> APIRouter:
_redis, _config, _token_dep = redis, config, token_dep
return router

# Auth dependency for the job endpoints. The original code used
# `Depends(lambda: _token_dep())`, which prevented FastAPI from resolving the
# Bearer credentials and raised AttributeError (HTTP 500) whenever jwt_enabled
# was true. This wrapper mirrors auth.get_token_dependency: enforce a valid
# Bearer token when JWT is on, and no-op when it is off.
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
_job_bearer = HTTPBearer(auto_error=False)


def _job_token_dep(credentials: Optional[HTTPAuthorizationCredentials] = Depends(_job_bearer)):
if not (_config or {}).get("security", {}).get("jwt_enabled", False):
return None
if credentials is None:
raise HTTPException(
status_code=401,
detail="Authentication required. Please provide a valid Bearer token.",
headers={"WWW-Authenticate": "Bearer"},
)
return verify_token(credentials)


# ---------- payload models --------------------------------------------------
class LlmJobPayload(BaseModel):
Expand All @@ -56,7 +77,7 @@ async def llm_job_enqueue(
payload: LlmJobPayload,
background_tasks: BackgroundTasks,
request: Request,
_td: Dict = Depends(lambda: _token_dep()), # late-bound dep
_td: Dict = Depends(_job_token_dep),
):
webhook_config = None
if payload.webhook_config:
Expand Down Expand Up @@ -87,7 +108,7 @@ async def llm_job_enqueue(
async def llm_job_status(
request: Request,
task_id: str,
_td: Dict = Depends(lambda: _token_dep())
_td: Dict = Depends(_job_token_dep)
):
return await handle_task_status(_redis, task_id, base_url=str(request.base_url))

Expand All @@ -97,7 +118,7 @@ async def llm_job_status(
async def crawl_job_enqueue(
payload: CrawlJobPayload,
background_tasks: BackgroundTasks,
_td: Dict = Depends(lambda: _token_dep()),
_td: Dict = Depends(_job_token_dep),
):
webhook_config = None
if payload.webhook_config:
Expand All @@ -123,6 +144,6 @@ async def crawl_job_enqueue(
async def crawl_job_status(
request: Request,
task_id: str,
_td: Dict = Depends(lambda: _token_dep())
_td: Dict = Depends(_job_token_dep)
):
return await handle_task_status(_redis, task_id, base_url=str(request.base_url))