diff --git a/requirements.txt b/requirements.txt index 3b6d536..0fdcbea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ dstack-sdk==0.5.3 cryptography==43.0.1 redis==5.2.1 nv-ppcie-verifier==1.5.0 +dcap-qvl>=0.3.13 diff --git a/src/app/api/v1/openai.py b/src/app/api/v1/openai.py index 5a48f1e..ce543d7 100644 --- a/src/app/api/v1/openai.py +++ b/src/app/api/v1/openai.py @@ -1,7 +1,11 @@ +import asyncio +import base64 import json import os +import time +from concurrent.futures import ThreadPoolExecutor from hashlib import sha256 -from typing import Optional +from typing import Any, Optional import httpx from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, Header, Query @@ -48,6 +52,19 @@ VLLM_COMPLETIONS_URL = f"{VLLM_BASE_URL}/v1/completions" VLLM_METRICS_URL = f"{VLLM_BASE_URL}/metrics" VLLM_MODELS_URL = f"{VLLM_BASE_URL}/v1/models" + +CHUTES_ENABLED = os.getenv("CHUTES_ENABLED", "false").lower() in ("1", "true", "yes", "on") +CHUTES_BASE_URL = os.getenv("CHUTES_BASE_URL", "https://llm.chutes.ai").rstrip("/") +CHUTES_ATTESTATION_BASE_URL = os.getenv("CHUTES_ATTESTATION_BASE_URL", "https://api.chutes.ai").rstrip("/") +CHUTES_CHAT_COMPLETIONS_URL = f"{CHUTES_BASE_URL}/v1/chat/completions" +CHUTES_MODELS_URL = f"{CHUTES_BASE_URL}/v1/models" +CHUTES_API_KEY = os.getenv("CHUTES_API_KEY") +CHUTES_CHUTE_ID_CACHE: dict[str, tuple[str, float]] = {} +CHUTES_CHUTE_ID_CACHE_TTL_SECONDS = int(os.getenv("CHUTES_CHUTE_ID_CACHE_TTL_SECONDS", "3600")) + +# Shared executor for online TDX verification to avoid per-attestation thread creation. 
+TDX_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.getenv("TDX_ONLINE_WORKERS", "8"))) + TIMEOUT = 60 * 10 COMMON_HEADERS = {"Content-Type": "application/json", "Accept": "application/json"} @@ -72,15 +89,30 @@ def sign_chat(text: str): ) +def _with_outbound_headers(outbound_headers: Optional[dict[str, str]] = None) -> dict[str, str]: + headers = dict(COMMON_HEADERS) + if outbound_headers: + headers.update(outbound_headers) + return headers + + +def _chutes_auth_headers() -> dict[str, str]: + if CHUTES_API_KEY: + return {"Authorization": f"Bearer {CHUTES_API_KEY}"} + return {} + + async def stream_vllm_response( url: str, request_body: bytes, modified_request_body: bytes, request_hash: Optional[str] = None, e2ee_ctx=None, + outbound_headers: Optional[dict[str, str]] = None, + model_name: Optional[str] = None, ): """ - Handle streaming vllm request + Handle streaming backend request. Args: request_body: The original request body modified_request_body: The modified enhanced request body @@ -130,15 +162,19 @@ async def generate_stream(response): # Cache the full request and response using the extracted cache key if chat_id: cache.set_chat( - chat_id, json.dumps(sign_chat(f"{request_sha256}:{response_sha256}")) + chat_id, + json.dumps(sign_chat(f"{request_sha256}:{response_sha256}")), + model_name=model_name, ) else: error_message = "Chat id could not be extracted from the response" log.error(error_message) raise Exception(error_message) - client = httpx.AsyncClient(timeout=httpx.Timeout(TIMEOUT), headers=COMMON_HEADERS) - # Forward the request to the vllm backend + client = httpx.AsyncClient( + timeout=httpx.Timeout(TIMEOUT), + headers=_with_outbound_headers(outbound_headers), + ) req = client.build_request("POST", url, content=modified_request_body) response = await client.send(req, stream=True) # If not 200, return the error response directly without streaming @@ -171,6 +207,8 @@ async def non_stream_vllm_response( modified_request_body: bytes, request_hash: 
Optional[str] = None, e2ee_ctx=None, + outbound_headers: Optional[dict[str, str]] = None, + model_name: Optional[str] = None, ): """ Handle non-streaming responses @@ -191,7 +229,8 @@ async def non_stream_vllm_response( log.debug(f"Calculated request hash: {request_sha256}") async with httpx.AsyncClient( - timeout=httpx.Timeout(TIMEOUT), headers=COMMON_HEADERS + timeout=httpx.Timeout(TIMEOUT), + headers=_with_outbound_headers(outbound_headers), ) as client: response = await client.post(url, content=modified_request_body) if response.status_code != 200: @@ -205,7 +244,9 @@ async def non_stream_vllm_response( if chat_id: response_sha256 = sha256(json.dumps(response_data).encode("utf-8")).hexdigest() cache.set_chat( - chat_id, json.dumps(sign_chat(f"{request_sha256}:{response_sha256}")) + chat_id, + json.dumps(sign_chat(f"{request_sha256}:{response_sha256}")), + model_name=model_name, ) else: raise Exception("Chat id could not be extracted from the response") @@ -237,6 +278,303 @@ def strip_empty_tool_calls(payload: dict) -> dict: return payload +def _normalize_signing_algo(signing_algo: str | None) -> str: + algo = ECDSA if signing_algo is None else signing_algo.strip().lower() + if algo not in [ECDSA, ED25519]: + raise ValueError("invalid_signing_algo") + return algo + + +def _build_proxy_attestation(signing_algo: str, nonce: str | None) -> dict: + context = ecdsa_context if signing_algo == ECDSA else ed25519_context + attestation = dict(generate_attestation(context, nonce)) + attestation["signing_public_key"] = local_model_public_key_hex(signing_algo) + + resp = dict(attestation) + resp["signing_public_key"] = attestation["signing_public_key"] + resp["all_attestations"] = [attestation] + return resp + + +def _error_from_upstream_429(response: httpx.Response): + retry_after = response.headers.get("Retry-After") + msg = "Upstream attestation is rate limited" + if retry_after: + msg = f"{msg}; retry after {retry_after} seconds" + return error(status_code=429, 
message=msg, type="upstream_rate_limited") + + +async def _resolve_chute_id(client: httpx.AsyncClient, model: str) -> str | dict: + now = time.time() + cached = CHUTES_CHUTE_ID_CACHE.get(model) + if cached and now - cached[1] < CHUTES_CHUTE_ID_CACHE_TTL_SECONDS: + return cached[0] + + resp = await client.get( + f"{CHUTES_ATTESTATION_BASE_URL}/chutes/", + params={"include_public": "true", "name": model}, + ) + if resp.status_code == 429: + return _error_from_upstream_429(resp) + if resp.status_code != 200: + return error( + status_code=502, + message=f"Failed to lookup chute by name: {resp.status_code} {resp.text}", + type="upstream_http_error", + ) + + data = resp.json() + items = data.get("items") or [] + if not items: + return error(status_code=404, message=f"No chute found for model: {model}", type="upstream_model_not_found") + + chute_id = items[0].get("chute_id") + if not chute_id: + return error(status_code=502, message="Upstream chute lookup missing chute_id", type="upstream_invalid_response") + + CHUTES_CHUTE_ID_CACHE[model] = (chute_id, now) + return chute_id + + +async def _fetch_chutes_attestation(client: httpx.AsyncClient, model: str, nonce: str) -> tuple[dict | None, dict | None]: + chute_id_or_error = await _resolve_chute_id(client, model) + if isinstance(chute_id_or_error, dict) and chute_id_or_error.get("error"): + return None, chute_id_or_error + chute_id = chute_id_or_error + + e2e_resp = await client.get(f"{CHUTES_ATTESTATION_BASE_URL}/e2e/instances/{chute_id}") + if e2e_resp.status_code == 429: + return None, _error_from_upstream_429(e2e_resp) + if e2e_resp.status_code != 200: + return None, error( + status_code=502, + message=f"Failed to fetch E2E public keys: {e2e_resp.status_code} {e2e_resp.text}", + type="upstream_http_error", + ) + + e2e_data = e2e_resp.json() + instances = e2e_data.get("instances") or [] + pubkeys = {i.get("instance_id"): i.get("e2e_pubkey") for i in instances if i.get("instance_id")} + + evidence_resp = await client.get( 
+ f"{CHUTES_ATTESTATION_BASE_URL}/chutes/{chute_id}/evidence", + params={"nonce": nonce}, + ) + if evidence_resp.status_code == 429: + return None, _error_from_upstream_429(evidence_resp) + if evidence_resp.status_code != 200: + return None, error( + status_code=502, + message=f"Failed to fetch evidence: {evidence_resp.status_code} {evidence_resp.text}", + type="upstream_http_error", + ) + + evidence_data = evidence_resp.json() + evidence_list = evidence_data.get("evidence") or [] + + all_attestations = [] + for e in evidence_list: + iid = e.get("instance_id") + e2e_pubkey = pubkeys.get(iid) + if not iid or not e2e_pubkey: + continue + all_attestations.append( + { + "instance_id": iid, + "nonce": nonce, + "e2e_pubkey": e2e_pubkey, + "intel_quote": e.get("quote"), + "gpu_evidence": e.get("gpu_evidence", []), + "gpu_tokens": e.get("gpu_tokens"), + "tdx_verification": e.get("tdx_verification"), + "certificate": e.get("certificate"), + } + ) + + if not all_attestations: + return None, error(status_code=502, message="No usable upstream attestations returned", type="upstream_invalid_response") + + return { + "attestation_type": "chutes", + "nonce": nonce, + "chute_id": chute_id, + "all_attestations": all_attestations, + }, None + + +def _decode_quote(quote_b64: str) -> bytes: + return base64.b64decode(quote_b64, validate=True) + + +def _extract_td_attributes(quote_bytes: bytes) -> int: + body = quote_bytes[48 : 48 + 584] + td_attributes = int.from_bytes(body[120:128], "little") + return td_attributes + + +def _extract_report_data_sha256(quote_bytes: bytes) -> str: + td_report_bytes = quote_bytes[48:632] + report_data_hex = td_report_bytes[520:584].hex().lower() + return report_data_hex[:64] + + +def _decode_jwt_payload_without_verification(token: str) -> dict[str, Any]: + parts = token.split(".") + if len(parts) < 2: + raise ValueError("invalid_jwt_format") + payload = parts[1] + payload += "=" * (-len(payload) % 4) + decoded = 
base64.urlsafe_b64decode(payload.encode("utf-8")).decode("utf-8") + return json.loads(decoded) + + +def _extract_gpu_tokens(attestation: dict[str, Any]) -> Any: + if "gpu_tokens" in attestation: + return attestation.get("gpu_tokens") + return attestation.get("gpu_evidence") + + +def _verify_tdx_online(quote: str | bytes) -> dict[str, Any]: + """Run online TDX verification via dcap_qvl synchronously. + + Accepts either base64-encoded quote (str) or raw bytes. + """ + try: + import dcap_qvl + + quote_bytes = _decode_quote(quote) if isinstance(quote, str) else quote + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + verified_report = loop.run_until_complete(dcap_qvl.get_collateral_and_verify(quote_bytes)) + finally: + loop.close() + + result = json.loads(verified_report.to_json()) + return {"result": result, "error": None} + except Exception as exc: + return {"result": None, "error": str(exc)} + + +async def _verify_tdx_online_async(quote: str | bytes) -> dict[str, Any]: + """Async wrapper that runs TDX verification in the shared executor. + + This keeps the event loop responsive while using a bounded thread pool + for the heavy verification work. 
+ """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor(TDX_EXECUTOR, _verify_tdx_online, quote) + + +async def _verify_single_chutes_attestation(attestation: dict[str, Any], nonce: str) -> dict[str, Any]: + errors: list[str] = [] + detail: dict[str, Any] = { + "instance_id": attestation.get("instance_id"), + "tdx_status": "missing", + "tdx_error_present": False, + "debug_mode_disabled": False, + "binding_verified": False, + "gpu_token_checked": False, + "gpu_verified": None, + } + + quote_b64 = attestation.get("intel_quote") or attestation.get("quote") + e2e_pubkey = attestation.get("e2e_pubkey") + if not quote_b64: + detail["errors"] = ["missing_intel_quote"] + return detail + if not e2e_pubkey: + detail["errors"] = ["missing_e2e_pubkey"] + return detail + + try: + quote_bytes = _decode_quote(quote_b64) + except Exception: + detail["errors"] = ["invalid_quote_base64"] + return detail + + tdx_verification = await _verify_tdx_online_async(quote_bytes) + tdx_error = tdx_verification.get("error") + if tdx_error: + detail["tdx_error_present"] = True + errors.append("tdx_online_verification_error") + + tdx_result = tdx_verification.get("result") + if not tdx_result: + if not tdx_error: + errors.append("tdx_status_missing") + else: + tdx_status = tdx_result.get("status") or "missing" + detail["tdx_status"] = tdx_status + if tdx_status != "UpToDate": + errors.append(f"tdx_status_not_uptodate:{tdx_status}") + + try: + td_attributes = _extract_td_attributes(quote_bytes) + detail["debug_mode_disabled"] = not (td_attributes & 1) + if td_attributes & 1: + errors.append("tdx_debug_mode_enabled") + except Exception: + errors.append("tdx_attributes_parse_failed") + + expected_report_data = sha256((nonce + e2e_pubkey).encode("utf-8")).hexdigest().lower() + actual_report_data = _extract_report_data_sha256(quote_bytes) + detail["expected_report_data"] = expected_report_data + detail["actual_report_data"] = actual_report_data + if actual_report_data != 
expected_report_data: + errors.append("report_data_binding_mismatch") + else: + detail["binding_verified"] = True + + gpu_tokens = _extract_gpu_tokens(attestation) + if isinstance(gpu_tokens, dict): + if gpu_tokens.get("error"): + detail["gpu_token_checked"] = True + detail["gpu_verified"] = False + errors.append("gpu_tokens_error") + tokens = gpu_tokens.get("tokens") + if tokens: + detail["gpu_token_checked"] = True + try: + platform_entry = tokens[0] + if not isinstance(platform_entry, list) or len(platform_entry) < 2: + detail["gpu_verified"] = False + errors.append("gpu_platform_token_format_invalid") + else: + platform_claims = _decode_jwt_payload_without_verification(platform_entry[1]) + overall_ok = platform_claims.get("x-nvidia-overall-att-result") is True + nonce_ok = platform_claims.get("eat_nonce") == expected_report_data + detail["gpu_verified"] = overall_ok and nonce_ok + if not overall_ok: + errors.append("gpu_overall_attestation_failed") + if not nonce_ok: + errors.append("gpu_eat_nonce_mismatch") + except Exception: + detail["gpu_verified"] = False + errors.append("gpu_tokens_parse_failed") + + detail["errors"] = errors + return detail + + +async def _verify_chutes_attestation_bundle(attestation_bundle: dict[str, Any], nonce: str) -> tuple[bool, list[dict[str, Any]]]: + details: list[dict[str, Any]] = [] + attestations = attestation_bundle.get("all_attestations") or [] + if not attestations: + return False, [{"instance_id": None, "errors": ["missing_all_attestations"]}] + + # Verify all instances concurrently, bounded by TDX_EXECUTOR size. 
+ tasks = [ + _verify_single_chutes_attestation(att, nonce) + for att in attestations + ] + details = await asyncio.gather(*tasks) + + all_ok = all(not d.get("errors") for d in details) + return all_ok, details + + # Get attestation report of intel quote and nvidia payload @router.get("/attestation/report", dependencies=[Depends(verify_authorization_header)]) async def attestation_report( @@ -245,35 +583,173 @@ async def attestation_report( nonce: str | None = Query(None), signing_address: str | None = Query(None), ): - signing_algo = ECDSA if signing_algo is None else signing_algo - if signing_algo not in [ECDSA, ED25519]: + try: + algo = _normalize_signing_algo(signing_algo) + except ValueError: return invalid_signing_algo() - context = ecdsa_context if signing_algo == ECDSA else ed25519_context + try: + return _build_proxy_attestation(algo, nonce) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.get("/attestation/chain", dependencies=[Depends(verify_authorization_header)]) +async def attestation_chain( + request: Request, + model: str = Query(...), + nonce: str = Query(...), + signing_algo: str | None = None, + verify_mode: str = Query("proxy"), +): + if not CHUTES_ENABLED: + return error(status_code=503, message="Chutes route is disabled", type="chutes_disabled") + + if not CHUTES_API_KEY: + return error(status_code=503, message="CHUTES_API_KEY is not configured", type="chutes_misconfigured") try: - attestation = dict(generate_attestation(context, nonce)) + algo = _normalize_signing_algo(signing_algo) + except ValueError: + return invalid_signing_algo() + + nonce = nonce.strip() + model = model.strip() + mode = verify_mode.strip().lower() + if mode not in {"proxy", "passthrough"}: + return error(status_code=400, message="verify_mode must be one of: proxy, passthrough", type="invalid_verify_mode") + + if len(nonce) < 16: + return error( + status_code=400, + message="nonce must be at least 16 characters", + 
type="invalid_nonce", + ) + if not model: + return error(status_code=400, message="model must not be empty", type="invalid_model") + + try: + proxy_attestation = _build_proxy_attestation(algo, nonce) except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) - attestation["signing_public_key"] = local_model_public_key_hex(signing_algo) - resp = dict(attestation) - resp["signing_public_key"] = attestation["signing_public_key"] - resp["all_attestations"] = [attestation] - return resp + try: + async with httpx.AsyncClient( + timeout=httpx.Timeout(TIMEOUT), + headers={"Authorization": f"Bearer {CHUTES_API_KEY}"}, + ) as client: + upstream_attestation, upstream_error = await _fetch_chutes_attestation(client, model, nonce) + except httpx.RequestError as exc: + return error(status_code=502, message=f"Failed to fetch upstream attestation: {exc}", type="upstream_unreachable") + + if upstream_error is not None: + return upstream_error + + upstream_raw = json.dumps(upstream_attestation, sort_keys=True, separators=(",", ":")) + upstream_attestation_sha256 = sha256(upstream_raw.encode("utf-8")).hexdigest() + + context = ecdsa_context if algo == ECDSA else ed25519_context + + binding_payload = { + "nonce": nonce, + "timestamp": int(time.time()), + "provider": "chutes", + "upstream_base_url": CHUTES_ATTESTATION_BASE_URL, + "model": model, + "upstream_attestation_sha256": upstream_attestation_sha256, + } + binding_text = json.dumps(binding_payload, sort_keys=True, separators=(",", ":")) + binding_proof = { + "payload": binding_payload, + "signature": sign_message(context, binding_text), + "signing_algo": algo, + "signing_address": context.signing_address, + } + + if mode == "passthrough": + return { + "version": "1", + "verify_mode": mode, + "proxy": { + "attestation": proxy_attestation, + "signing_public_key": proxy_attestation.get("signing_public_key"), + }, + "upstream": { + "provider": "chutes", + "base_url": CHUTES_ATTESTATION_BASE_URL, + "model": model, 
+ "attestation": upstream_attestation, + "attestation_sha256": upstream_attestation_sha256, + }, + "binding_proof": binding_proof, + } + + verified, verification_details = await _verify_chutes_attestation_bundle(upstream_attestation, nonce) + if not verified: + log.error( + "Chutes attestation verification failed in proxy mode: %s", + json.dumps(verification_details, separators=(",", ":")), + ) + return error( + status_code=502, + message="Chutes attestation verification failed in proxy mode", + type="chutes_verification_failed", + ) + total_instances = len(verification_details) + uptodate_instances = sum(1 for d in verification_details if d.get("tdx_status") == "UpToDate") + binding_verified_instances = sum(1 for d in verification_details if d.get("binding_verified") is True) + + receipt_payload = { + "nonce": nonce, + "request_hash": sha256(f"{model}:{nonce}".encode("utf-8")).hexdigest(), + "provider": "chutes", + "model": model, + "verify_mode": mode, + "verification_policy": "chutes-v1", + "verification_policy_version": "1", + "upstream_attestation_sha256": upstream_attestation_sha256, + "binding_signature": binding_proof["signature"], + "binding_signing_algo": binding_proof["signing_algo"], + "binding_signing_address": binding_proof["signing_address"], + "verified_at": int(time.time()), + "result": "pass", + "verification_summary": { + "total_instances": total_instances, + "tdx_uptodate_instances": uptodate_instances, + "binding_verified_instances": binding_verified_instances, + }, + "instance_results": verification_details, + } + receipt_text = json.dumps(receipt_payload, sort_keys=True, separators=(",", ":")) + + return { + "version": "1", + "verify_mode": mode, + "proxy": { + "attestation": proxy_attestation, + "signing_public_key": proxy_attestation.get("signing_public_key"), + }, + "verification_receipt": { + "payload": receipt_payload, + "signature": sign_message(context, receipt_text), + "signing_algo": algo, + "signing_address": 
context.signing_address, + }, + } -# VLLM Chat completions -@router.post("/chat/completions", dependencies=[Depends(verify_authorization_header)]) -async def chat_completions( + +async def _chat_completions_impl( request: Request, - x_request_hash: Optional[str] = Header(None, alias="X-Request-Hash"), - x_signing_algo: Optional[str] = Header(None, alias="X-Signing-Algo"), - x_client_pub_key: Optional[str] = Header(None, alias="X-Client-Pub-Key"), - x_model_pub_key: Optional[str] = Header(None, alias="X-Model-Pub-Key"), - x_e2ee_version: Optional[str] = Header(None, alias="X-E2EE-Version"), - x_e2ee_nonce: Optional[str] = Header(None, alias="X-E2EE-Nonce"), - x_e2ee_timestamp: Optional[str] = Header(None, alias="X-E2EE-Timestamp"), + x_request_hash: Optional[str], + x_signing_algo: Optional[str], + x_client_pub_key: Optional[str], + x_model_pub_key: Optional[str], + x_e2ee_version: Optional[str], + x_e2ee_nonce: Optional[str], + x_e2ee_timestamp: Optional[str], + backend_url: str, + outbound_headers: Optional[dict[str, str]] = None, ): # Keep original request body to calculate the request hash for attestation request_body = await request.body() @@ -299,25 +775,104 @@ async def chat_completions( modified_json = strip_empty_tool_calls(request_json) # Check if the request is for streaming or non-streaming - is_stream = modified_json.get( - "stream", False - ) # Default to non-streaming if not specified + is_stream = modified_json.get("stream", False) + request_model = modified_json.get("model") modified_request_body = json.dumps(modified_json).encode("utf-8") + if is_stream: - # Create a streaming response return await stream_vllm_response( - VLLM_URL, request_body, modified_request_body, x_request_hash, e2ee_ctx - ) - else: - # Handle non-streaming response - response_data = await non_stream_vllm_response( - VLLM_URL, request_body, modified_request_body, x_request_hash, e2ee_ctx - ) - return JSONResponse( - content=response_data, - 
headers=get_e2ee_response_headers(e2ee_ctx), + backend_url, + request_body, + modified_request_body, + x_request_hash, + e2ee_ctx, + outbound_headers=outbound_headers, + model_name=request_model, ) + response_data = await non_stream_vllm_response( + backend_url, + request_body, + modified_request_body, + x_request_hash, + e2ee_ctx, + outbound_headers=outbound_headers, + model_name=request_model, + ) + return JSONResponse( + content=response_data, + headers=get_e2ee_response_headers(e2ee_ctx), + ) + + +# Chat completions (compat route): +# - CHUTES_ENABLED=false -> original vLLM backend behavior +# - CHUTES_ENABLED=true -> transparently route to Chutes backend +@router.post("/chat/completions", dependencies=[Depends(verify_authorization_header)]) +async def chat_completions( + request: Request, + x_request_hash: Optional[str] = Header(None, alias="X-Request-Hash"), + x_signing_algo: Optional[str] = Header(None, alias="X-Signing-Algo"), + x_client_pub_key: Optional[str] = Header(None, alias="X-Client-Pub-Key"), + x_model_pub_key: Optional[str] = Header(None, alias="X-Model-Pub-Key"), + x_e2ee_version: Optional[str] = Header(None, alias="X-E2EE-Version"), + x_e2ee_nonce: Optional[str] = Header(None, alias="X-E2EE-Nonce"), + x_e2ee_timestamp: Optional[str] = Header(None, alias="X-E2EE-Timestamp"), +): + backend_url = VLLM_URL + outbound_headers = None + + if CHUTES_ENABLED: + if not CHUTES_API_KEY: + return error(status_code=503, message="CHUTES_API_KEY is not configured", type="chutes_misconfigured") + backend_url = CHUTES_CHAT_COMPLETIONS_URL + outbound_headers = _chutes_auth_headers() + + return await _chat_completions_impl( + request=request, + x_request_hash=x_request_hash, + x_signing_algo=x_signing_algo, + x_client_pub_key=x_client_pub_key, + x_model_pub_key=x_model_pub_key, + x_e2ee_version=x_e2ee_version, + x_e2ee_nonce=x_e2ee_nonce, + x_e2ee_timestamp=x_e2ee_timestamp, + backend_url=backend_url, + outbound_headers=outbound_headers, + ) + + +# Chutes chat 
completions (new path, side-by-side with existing logic) +@router.post("/chutes/chat/completions", dependencies=[Depends(verify_authorization_header)]) +async def chutes_chat_completions( + request: Request, + x_request_hash: Optional[str] = Header(None, alias="X-Request-Hash"), + x_signing_algo: Optional[str] = Header(None, alias="X-Signing-Algo"), + x_client_pub_key: Optional[str] = Header(None, alias="X-Client-Pub-Key"), + x_model_pub_key: Optional[str] = Header(None, alias="X-Model-Pub-Key"), + x_e2ee_version: Optional[str] = Header(None, alias="X-E2EE-Version"), + x_e2ee_nonce: Optional[str] = Header(None, alias="X-E2EE-Nonce"), + x_e2ee_timestamp: Optional[str] = Header(None, alias="X-E2EE-Timestamp"), +): + if not CHUTES_ENABLED: + return error(status_code=503, message="Chutes route is disabled", type="chutes_disabled") + + if not CHUTES_API_KEY: + return error(status_code=503, message="CHUTES_API_KEY is not configured", type="chutes_misconfigured") + + return await _chat_completions_impl( + request=request, + x_request_hash=x_request_hash, + x_signing_algo=x_signing_algo, + x_client_pub_key=x_client_pub_key, + x_model_pub_key=x_model_pub_key, + x_e2ee_version=x_e2ee_version, + x_e2ee_nonce=x_e2ee_nonce, + x_e2ee_timestamp=x_e2ee_timestamp, + backend_url=CHUTES_CHAT_COMPLETIONS_URL, + outbound_headers=_chutes_auth_headers(), + ) + # VLLM completions @router.post("/completions", dependencies=[Depends(verify_authorization_header)]) @@ -356,24 +911,33 @@ async def completions( is_stream = modified_json.get( "stream", False ) # Default to non-streaming if not specified + request_model = modified_json.get("model") modified_request_body = json.dumps(modified_json).encode("utf-8") if is_stream: # Create a streaming response return await stream_vllm_response( - VLLM_COMPLETIONS_URL, request_body, modified_request_body, x_request_hash + VLLM_COMPLETIONS_URL, + request_body, + modified_request_body, + x_request_hash, + model_name=request_model, ) else: # Handle 
non-streaming response response_data = await non_stream_vllm_response( - VLLM_COMPLETIONS_URL, request_body, modified_request_body, x_request_hash + VLLM_COMPLETIONS_URL, + request_body, + modified_request_body, + x_request_hash, + model_name=request_model, ) return JSONResponse(content=response_data) # Get signature for chat_id of chat history @router.get("/signature/{chat_id}", dependencies=[Depends(verify_authorization_header)]) -async def signature(request: Request, chat_id: str, signing_algo: str = None): - cache_value = cache.get_chat(chat_id) +async def signature(request: Request, chat_id: str, signing_algo: str = None, model: Optional[str] = None): + cache_value = cache.get_chat(chat_id, model_name=model) if cache_value is None: return not_found("Chat id not found or expired") @@ -410,7 +974,7 @@ async def signature(request: Request, chat_id: str, signing_algo: str = None): async def metrics(request: Request): # Get local metrics from the proxy local_metrics = get_proxy_metrics() - + # Fetch metrics from the vLLM backend try: async with httpx.AsyncClient(timeout=httpx.Timeout(TIMEOUT)) as client: @@ -423,7 +987,7 @@ async def metrics(request: Request): except Exception as e: log.error(f"Error fetching vLLM metrics: {e}") remote_metrics = f"# Error fetching vLLM metrics: {e}" - + # Combine both and return combined_metrics = f"{local_metrics}\n\n# --- vLLM Backend Metrics ---\n\n{remote_metrics}" return PlainTextResponse(combined_metrics) @@ -436,3 +1000,21 @@ async def models(request: Request): if response.status_code != 200: raise HTTPException(status_code=response.status_code, detail=response.text) return JSONResponse(content=response.json()) + + +@router.get("/chutes/models", dependencies=[Depends(verify_authorization_header)]) +async def chutes_models(request: Request): + if not CHUTES_ENABLED: + return error(status_code=503, message="Chutes route is disabled", type="chutes_disabled") + + if not CHUTES_API_KEY: + return error(status_code=503, 
message="CHUTES_API_KEY is not configured", type="chutes_misconfigured") + + async with httpx.AsyncClient( + timeout=httpx.Timeout(TIMEOUT), + headers=_with_outbound_headers(_chutes_auth_headers()), + ) as client: + response = await client.get(CHUTES_MODELS_URL) + if response.status_code != 200: + raise HTTPException(status_code=response.status_code, detail=response.text) + return JSONResponse(content=response.json()) diff --git a/src/app/cache/cache.py b/src/app/cache/cache.py index d629c95..5e0725d 100644 --- a/src/app/cache/cache.py +++ b/src/app/cache/cache.py @@ -1,4 +1,3 @@ -import json import os from typing import Optional @@ -9,7 +8,11 @@ CHAT_CACHE_EXPIRATION = int(os.getenv("CHAT_CACHE_EXPIRATION", "1200")) MODEL_NAME = os.getenv("MODEL_NAME") -if not MODEL_NAME: +CHUTES_ENABLED = os.getenv("CHUTES_ENABLED", "false").lower() in ("1", "true", "yes", "on") +# Stable namespace used for backward-compatible lookup when request model is dynamic. +CHUTES_FALLBACK_MODEL_NAME = os.getenv("CHUTES_FALLBACK_MODEL_NAME", "chutes") + +if not CHUTES_ENABLED and not MODEL_NAME: raise ValueError("MODEL_NAME is not set") CHAT_PREFIX = "chat" @@ -35,9 +38,23 @@ def _init_redis(self) -> Optional[RedisCache]: return None return RedisCache(expiration=CHAT_CACHE_EXPIRATION) - def _make_key(self, prefix: str, key: str) -> str: + def _resolve_model_name(self, request_model: Optional[str] = None) -> str: + """Resolve model namespace for cache key.""" + if request_model: + return request_model + + if MODEL_NAME: + return MODEL_NAME + + if CHUTES_ENABLED: + return CHUTES_FALLBACK_MODEL_NAME + + raise ValueError("MODEL_NAME is not set") + + def _make_key(self, prefix: str, key: str, model_name: Optional[str] = None) -> str: """Build namespaced cache key: model:prefix:key""" - return f"{MODEL_NAME}:{prefix}:{key}" + resolved_model_name = self._resolve_model_name(model_name) + return f"{resolved_model_name}:{prefix}:{key}" def _write_string(self, key: str, value: str) -> None: """Write 
string to local and optionally to Redis.""" @@ -63,16 +80,37 @@ def _read_string(self, key: str) -> Optional[str]: # Chat operations - def set_chat(self, chat_id: str, chat: str) -> None: + def set_chat(self, chat_id: str, chat: str, model_name: Optional[str] = None) -> None: """Store chat completion data.""" - key = self._make_key(CHAT_PREFIX, chat_id) + key = self._make_key(CHAT_PREFIX, chat_id, model_name=model_name) self._write_string(key, chat) - def get_chat(self, chat_id: str) -> Optional[str]: + # Backward-compatible fallback key for dynamic-model proxy mode. + if CHUTES_ENABLED and model_name: + fallback_key = self._make_key( + CHAT_PREFIX, + chat_id, + model_name=CHUTES_FALLBACK_MODEL_NAME, + ) + if fallback_key != key: + self._write_string(fallback_key, chat) + + def get_chat(self, chat_id: str, model_name: Optional[str] = None) -> Optional[str]: """Retrieve chat completion data.""" - key = self._make_key(CHAT_PREFIX, chat_id) - return self._read_string(key) - + key = self._make_key(CHAT_PREFIX, chat_id, model_name=model_name) + value = self._read_string(key) + if value: + return value + + if CHUTES_ENABLED and model_name and model_name != CHUTES_FALLBACK_MODEL_NAME: + fallback_key = self._make_key( + CHAT_PREFIX, + chat_id, + model_name=CHUTES_FALLBACK_MODEL_NAME, + ) + return self._read_string(fallback_key) + + return None cache = ChatCache() diff --git a/tests/app/test_openai_e2ee.py b/tests/app/test_openai_e2ee.py index 2b85760..441e0dd 100644 --- a/tests/app/test_openai_e2ee.py +++ b/tests/app/test_openai_e2ee.py @@ -1,3 +1,7 @@ +import base64 +import hashlib +import json + import httpx import pytest from fastapi.testclient import TestClient @@ -332,3 +336,271 @@ def test_attestation_report_includes_signing_public_key(): assert "signing_public_key" in data assert len(data["signing_public_key"]) == 64 assert data["all_attestations"][0]["signing_public_key"] == data["signing_public_key"] + + +def _make_chutes_quote_b64(nonce: str, e2e_pubkey: str, *, 
debug_enabled: bool = False) -> str: + quote_bytes = bytearray(700) + + td_attributes_offset = 48 + 120 + quote_bytes[td_attributes_offset : td_attributes_offset + 8] = (1 if debug_enabled else 0).to_bytes(8, "little") + + report_data_offset = 48 + 520 + report_data_hex = hashlib.sha256((nonce + e2e_pubkey).encode("utf-8")).hexdigest() + report_data_bytes = bytes.fromhex(report_data_hex) + bytes(32) + quote_bytes[report_data_offset : report_data_offset + 64] = report_data_bytes + + return base64.b64encode(bytes(quote_bytes)).decode("utf-8") + + +@pytest.mark.asyncio +@pytest.mark.respx +async def test_attestation_chain_success_proxy_mode(respx_mock): + model = "moonshotai/Kimi-K2.5-TEE" + nonce = "a" * 16 + e2e_pubkey = "pk-1" + quote_b64 = _make_chutes_quote_b64(nonce, e2e_pubkey) + + respx_mock.get("https://api.chutes.ai/chutes/").mock( + return_value=httpx.Response(200, json={"items": [{"chute_id": "chute-123"}]}) + ) + respx_mock.get("https://api.chutes.ai/e2e/instances/chute-123").mock( + return_value=httpx.Response( + 200, + json={"instances": [{"instance_id": "inst-1", "e2e_pubkey": e2e_pubkey}]}, + ) + ) + respx_mock.get("https://api.chutes.ai/chutes/chute-123/evidence").mock( + return_value=httpx.Response( + 200, + json={ + "evidence": [ + { + "instance_id": "inst-1", + "quote": quote_b64, + "tdx_verification": {"result": {"status": "UpToDate"}}, + "certificate": "cert", + } + ] + }, + ) + ) + + with patch("app.api.v1.openai.CHUTES_ENABLED", True), patch("app.api.v1.openai.CHUTES_API_KEY", "test-key"), patch( + "app.api.v1.openai._verify_tdx_online", + return_value={"result": {"status": "UpToDate"}, "error": None}, + ): + response = client.get( + "/v1/attestation/chain", + params={"model": model, "nonce": nonce, "signing_algo": "ecdsa"}, + headers={"Authorization": TEST_AUTH_HEADER}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["version"] == "1" + assert data["verify_mode"] == "proxy" + assert 
data["proxy"]["attestation"]["request_nonce"] == nonce + assert "verification_receipt" in data + assert data["verification_receipt"]["payload"]["result"] == "pass" + assert data["verification_receipt"]["payload"]["model"] == model + summary = data["verification_receipt"]["payload"]["verification_summary"] + assert summary["total_instances"] == 1 + assert summary["binding_verified_instances"] == 1 + assert len(data["verification_receipt"]["payload"]["instance_results"]) == 1 + + +@pytest.mark.asyncio +@pytest.mark.respx +async def test_attestation_chain_passthrough_mode_returns_upstream_bundle(respx_mock): + model = "moonshotai/Kimi-K2.5-TEE" + nonce = "b" * 16 + e2e_pubkey = "pk-2" + quote_b64 = _make_chutes_quote_b64(nonce, e2e_pubkey) + + respx_mock.get("https://api.chutes.ai/chutes/").mock( + return_value=httpx.Response(200, json={"items": [{"chute_id": "chute-321"}]}) + ) + respx_mock.get("https://api.chutes.ai/e2e/instances/chute-321").mock( + return_value=httpx.Response( + 200, + json={"instances": [{"instance_id": "inst-2", "e2e_pubkey": e2e_pubkey}]}, + ) + ) + respx_mock.get("https://api.chutes.ai/chutes/chute-321/evidence").mock( + return_value=httpx.Response( + 200, + json={ + "evidence": [ + { + "instance_id": "inst-2", + "quote": quote_b64, + "tdx_verification": {"result": {"status": "UpToDate"}}, + "certificate": "cert", + } + ] + }, + ) + ) + + with patch("app.api.v1.openai.CHUTES_ENABLED", True), patch("app.api.v1.openai.CHUTES_API_KEY", "test-key"): + response = client.get( + "/v1/attestation/chain", + params={"model": model, "nonce": nonce, "signing_algo": "ecdsa", "verify_mode": "passthrough"}, + headers={"Authorization": TEST_AUTH_HEADER}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["verify_mode"] == "passthrough" + expected_hash = hashlib.sha256( + json.dumps(data["upstream"]["attestation"], sort_keys=True, separators=(",", ":")).encode("utf-8") + ).hexdigest() + assert 
data["upstream"]["attestation_sha256"] == expected_hash + assert data["binding_proof"]["payload"]["upstream_attestation_sha256"] == expected_hash + + +def test_attestation_chain_proxy_mode_verification_failure_returns_502(): + model = "moonshotai/Kimi-K2.5-TEE" + nonce = "a" * 16 + + with patch("app.api.v1.openai.CHUTES_ENABLED", True), patch("app.api.v1.openai.CHUTES_API_KEY", "test-key"), patch( + "app.api.v1.openai._fetch_chutes_attestation", + return_value=( + { + "attestation_type": "chutes", + "nonce": nonce, + "chute_id": "chute-123", + "all_attestations": [ + { + "instance_id": "inst-1", + "e2e_pubkey": "pk-1", + "intel_quote": "bad-quote", + "tdx_verification": {"result": {"status": "OutOfDate"}}, + } + ], + }, + None, + ), + ): + response = client.get( + "/v1/attestation/chain", + params={"model": model, "nonce": nonce, "signing_algo": "ecdsa"}, + headers={"Authorization": TEST_AUTH_HEADER}, + ) + + assert response.status_code == 502 + assert response.json()["error"]["type"] == "chutes_verification_failed" + + +def test_attestation_chain_nonce_too_short(): + with patch("app.api.v1.openai.CHUTES_ENABLED", True), patch("app.api.v1.openai.CHUTES_API_KEY", "test-key"): + response = client.get( + "/v1/attestation/chain", + params={"model": "moonshotai/Kimi-K2.5-TEE", "nonce": "short", "signing_algo": "ecdsa"}, + headers={"Authorization": TEST_AUTH_HEADER}, + ) + + assert response.status_code == 400 + assert response.json()["error"]["type"] == "invalid_nonce" + + +@pytest.mark.asyncio +@pytest.mark.respx +async def test_attestation_chain_proxy_mode_uses_online_tdx_verification(respx_mock): + model = "moonshotai/Kimi-K2.5-TEE" + nonce = "d" * 16 + e2e_pubkey = "pk-3" + quote_b64 = _make_chutes_quote_b64(nonce, e2e_pubkey) + + respx_mock.get("https://api.chutes.ai/chutes/").mock( + return_value=httpx.Response(200, json={"items": [{"chute_id": "chute-777"}]}) + ) + respx_mock.get("https://api.chutes.ai/e2e/instances/chute-777").mock( + return_value=httpx.Response( 
+ 200, + json={"instances": [{"instance_id": "inst-3", "e2e_pubkey": e2e_pubkey}]}, + ) + ) + respx_mock.get("https://api.chutes.ai/chutes/chute-777/evidence").mock( + return_value=httpx.Response( + 200, + json={ + "evidence": [ + { + "instance_id": "inst-3", + "quote": quote_b64, + "certificate": "cert", + } + ] + }, + ) + ) + + with patch("app.api.v1.openai.CHUTES_ENABLED", True), patch("app.api.v1.openai.CHUTES_API_KEY", "test-key"), patch( + "app.api.v1.openai._verify_tdx_online", + return_value={"result": {"status": "UpToDate"}, "error": None}, + ): + response = client.get( + "/v1/attestation/chain", + params={"model": model, "nonce": nonce, "signing_algo": "ecdsa"}, + headers={"Authorization": TEST_AUTH_HEADER}, + ) + + assert response.status_code == 200 + assert response.json()["verification_receipt"]["payload"]["result"] == "pass" + + +def test_attestation_chain_proxy_mode_rejects_tdx_online_verification_error(): + model = "moonshotai/Kimi-K2.5-TEE" + nonce = "c" * 16 + + with patch("app.api.v1.openai.CHUTES_ENABLED", True), patch("app.api.v1.openai.CHUTES_API_KEY", "test-key"), patch( + "app.api.v1.openai._fetch_chutes_attestation", + return_value=( + { + "attestation_type": "chutes", + "nonce": nonce, + "chute_id": "chute-123", + "all_attestations": [ + { + "instance_id": "inst-1", + "e2e_pubkey": "pk-1", + "intel_quote": _make_chutes_quote_b64(nonce, "pk-1"), + } + ], + }, + None, + ), + ), patch("app.api.v1.openai._verify_tdx_online", return_value={"result": None, "error": "tdx verify failed"}): + response = client.get( + "/v1/attestation/chain", + params={"model": model, "nonce": nonce, "signing_algo": "ecdsa"}, + headers={"Authorization": TEST_AUTH_HEADER}, + ) + + assert response.status_code == 502 + assert response.json()["error"]["type"] == "chutes_verification_failed" + + +def test_attestation_chain_invalid_verify_mode(): + with patch("app.api.v1.openai.CHUTES_ENABLED", True), patch("app.api.v1.openai.CHUTES_API_KEY", "test-key"): + response = 
client.get( + "/v1/attestation/chain", + params={"model": "moonshotai/Kimi-K2.5-TEE", "nonce": "a" * 16, "signing_algo": "ecdsa", "verify_mode": "bad"}, + headers={"Authorization": TEST_AUTH_HEADER}, + ) + + assert response.status_code == 400 + assert response.json()["error"]["type"] == "invalid_verify_mode" + + +def test_attestation_chain_model_empty(): + with patch("app.api.v1.openai.CHUTES_ENABLED", True), patch("app.api.v1.openai.CHUTES_API_KEY", "test-key"): + response = client.get( + "/v1/attestation/chain", + params={"model": " ", "nonce": "a" * 16, "signing_algo": "ecdsa"}, + headers={"Authorization": TEST_AUTH_HEADER}, + ) + + assert response.status_code == 400 + assert response.json()["error"]["type"] == "invalid_model" diff --git a/tests/app/verify_attestation_chain.py b/tests/app/verify_attestation_chain.py new file mode 100644 index 0000000..a99d853 --- /dev/null +++ b/tests/app/verify_attestation_chain.py @@ -0,0 +1,182 @@ +import os +import json +import hashlib +import secrets +import requests +from requests.exceptions import ReadTimeout +from eth_account import Account +from eth_account.messages import encode_defunct +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey + +BASE_URL = os.environ.get("BASE_URL", "").rstrip("/") +API_KEY = os.environ.get("API_KEY", "") +MODEL_NAME = os.environ.get("MODEL_NAME", "") +SIGNING_ALGO = os.environ.get("SIGNING_ALGO", "ecdsa").lower() +CONNECT_TIMEOUT = int(os.environ.get("CONNECT_TIMEOUT", "15")) +READ_TIMEOUT = int(os.environ.get("READ_TIMEOUT", "300")) +MAX_RETRIES = int(os.environ.get("MAX_RETRIES", "3")) +VERIFY_MODE = os.environ.get("VERIFY_MODE", "both").lower() + + +def _canonical_json(obj) -> str: + return json.dumps(obj, sort_keys=True, separators=(",", ":")) + + +def _verify_binding_signature(binding_proof: dict): + payload_text = _canonical_json(binding_proof["payload"]) + signing_algo = binding_proof.get("signing_algo") + signature = binding_proof.get("signature") + 
signing_address = binding_proof.get("signing_address") + + if signing_algo == "ecdsa": + if not signature or not signature.startswith("0x"): + raise RuntimeError("invalid ecdsa signature format") + recovered = Account.recover_message( + encode_defunct(text=payload_text), + signature=signature, + ) + if recovered.lower() != (signing_address or "").lower(): + raise RuntimeError( + f"binding signature mismatch: recovered={recovered}, expected={signing_address}" + ) + return recovered + + if signing_algo == "ed25519": + # In this project ed25519 signing_address is the raw public key hex. + pubkey_hex = signing_address or "" + if len(pubkey_hex) != 64: + raise RuntimeError("invalid ed25519 signing_address/public key") + pubkey = Ed25519PublicKey.from_public_bytes(bytes.fromhex(pubkey_hex)) + pubkey.verify(bytes.fromhex(signature), payload_text.encode("utf-8")) + return pubkey_hex + + raise RuntimeError(f"unsupported signing_algo: {signing_algo}") + + +def _run_mode(verify_mode: str): + nonce = secrets.token_hex(32) + url = f"{BASE_URL}/v1/attestation/chain" + + resp = None + for attempt in range(1, MAX_RETRIES + 1): + try: + resp = requests.get( + url, + params={"model": MODEL_NAME, "nonce": nonce, "signing_algo": SIGNING_ALGO, "verify_mode": verify_mode}, + headers={"Authorization": f"Bearer {API_KEY}"}, + timeout=(CONNECT_TIMEOUT, READ_TIMEOUT), + ) + except ReadTimeout: + if attempt == MAX_RETRIES: + raise RuntimeError( + f"Read timeout after {MAX_RETRIES} attempts (connect={CONNECT_TIMEOUT}s, read={READ_TIMEOUT}s)." + ) + print(f"[{verify_mode}] attempt {attempt}/{MAX_RETRIES} timed out, retrying...") + continue + + print(f"[{verify_mode}] status:", resp.status_code) + if resp.status_code == 429: + retry_after = resp.headers.get("Retry-After") + print(f"[{verify_mode}] body:", resp.text) + if attempt == MAX_RETRIES: + raise RuntimeError( + "429 from /v1/attestation/chain (likely upstream Chutes attestation rate limit). " + f"Retry-After={retry_after}." 
+ ) + print(f"[{verify_mode}] attempt {attempt}/{MAX_RETRIES} got 429, retrying...") + continue + + if resp.status_code >= 400: + print(f"[{verify_mode}] body:", resp.text) + resp.raise_for_status() + break + + if resp is None: + raise RuntimeError("No response received") + + data = resp.json() + assert data.get("version") == "1", "unexpected chain version" + assert data.get("verify_mode") == verify_mode, f"unexpected verify_mode: {data.get('verify_mode')}" + assert "proxy" in data, "missing proxy section" + + proxy = data["proxy"] + proxy_att = proxy["attestation"] + assert proxy.get("signing_public_key"), "missing proxy.signing_public_key" + assert proxy_att.get("signing_public_key") == proxy["signing_public_key"], "proxy signing_public_key mismatch" + + if verify_mode == "passthrough": + assert "upstream" in data and "binding_proof" in data, "missing passthrough sections" + upstream = data["upstream"] + upstream_att = upstream["attestation"] + upstream_hash = hashlib.sha256(_canonical_json(upstream_att).encode("utf-8")).hexdigest() + assert upstream.get("attestation_sha256") == upstream_hash, "upstream attestation hash mismatch" + + bp = data["binding_proof"] + payload = bp["payload"] + assert payload.get("nonce") == nonce, "binding payload nonce mismatch" + assert payload.get("model") == MODEL_NAME, "binding payload model mismatch" + assert payload.get("upstream_attestation_sha256") == upstream_hash, "binding payload hash mismatch" + assert bp.get("signing_algo") == SIGNING_ALGO, "binding signing_algo mismatch" + assert bp.get("signature"), "binding signature missing" + + recovered_signer = _verify_binding_signature(bp) + proxy_signing_address = proxy_att.get("signing_address") + if proxy_signing_address: + assert recovered_signer.lower() == proxy_signing_address.lower(), ( + f"proxy signing address mismatch: recovered={recovered_signer}, proxy={proxy_signing_address}" + ) + + print("[OK] /v1/attestation/chain passthrough validated") + print("binding_signer:", 
recovered_signer) + print("upstream_attestation_sha256:", upstream_hash) + else: + assert "verification_receipt" in data, "missing verification_receipt" + receipt = data["verification_receipt"] + assert receipt.get("signature"), "receipt signature missing" + payload = receipt.get("payload") or {} + assert payload.get("result") == "pass", "receipt result is not pass" + assert payload.get("model") == MODEL_NAME, "receipt model mismatch" + assert payload.get("nonce") == nonce, "receipt nonce mismatch" + + summary = payload.get("verification_summary") or {} + instance_results = payload.get("instance_results") or [] + + summary_total = summary.get("total_instances") + if summary_total is not None and instance_results: + assert summary_total == len(instance_results), ( + f"summary total_instances mismatch: summary={summary_total}, instance_results={len(instance_results)}" + ) + elif summary_total is not None and not instance_results: + print("[WARN] verification_summary exists but instance_results missing/empty. 
This may indicate an older server build.") + + if instance_results: + assert summary.get("binding_verified_instances", 0) >= 1, "expected at least one binding-verified instance" + + print("[OK] /v1/attestation/chain proxy mode validated") + print("verification_summary:", json.dumps(summary, ensure_ascii=False)) + print("instance_results_count:", len(instance_results)) + print("instance_results:", json.dumps(instance_results, ensure_ascii=False)) + + print("nonce:", nonce) + print("model:", MODEL_NAME) + print("signing_algo:", SIGNING_ALGO) + print("verify_mode:", verify_mode) + + +def main(): + if not BASE_URL or not API_KEY or not MODEL_NAME: + raise RuntimeError("Please set BASE_URL, API_KEY, MODEL_NAME") + + if VERIFY_MODE == "both": + _run_mode("proxy") + _run_mode("passthrough") + return + + if VERIFY_MODE not in {"proxy", "passthrough"}: + raise RuntimeError("VERIFY_MODE must be one of: proxy, passthrough, both") + + _run_mode(VERIFY_MODE) + + +if __name__ == "__main__": + main() diff --git a/tests/app/verify_attestation_chain_negative.py b/tests/app/verify_attestation_chain_negative.py new file mode 100644 index 0000000..62adc6d --- /dev/null +++ b/tests/app/verify_attestation_chain_negative.py @@ -0,0 +1,37 @@ +import os +import requests + +BASE_URL = os.environ.get("BASE_URL", "").rstrip("/") +API_KEY = os.environ.get("API_KEY", "") +MODEL_NAME = os.environ.get("MODEL_NAME", "") +SIGNING_ALGO = os.environ.get("SIGNING_ALGO", "ecdsa").lower() + + +def main(): + if not BASE_URL or not API_KEY or not MODEL_NAME: + raise RuntimeError("Please set BASE_URL, API_KEY, MODEL_NAME") + + # nonce deliberately too short, should fail with invalid_nonce + resp = requests.get( + f"{BASE_URL}/v1/attestation/chain", + params={"model": MODEL_NAME, "nonce": "short", "signing_algo": SIGNING_ALGO}, + headers={"Authorization": f"Bearer {API_KEY}"}, + timeout=60, + ) + + print("status:", resp.status_code) + print("body:", resp.text) + + if resp.status_code != 400: + raise 
RuntimeError(f"Expected 400 for short nonce, got {resp.status_code}") + + data = resp.json() + err_type = (data.get("error") or {}).get("type") + if err_type != "invalid_nonce": + raise RuntimeError(f"Expected error.type=invalid_nonce, got {err_type}") + + print("[OK] negative check passed (invalid_nonce)") + + +if __name__ == "__main__": + main() diff --git a/tests/app/verify_attestation_report.py b/tests/app/verify_attestation_report.py new file mode 100644 index 0000000..6c7a3b3 --- /dev/null +++ b/tests/app/verify_attestation_report.py @@ -0,0 +1,47 @@ +import os +import requests + +BASE_URL = os.environ.get("BASE_URL", "").rstrip("/") +API_KEY = os.environ.get("API_KEY", "") +MODEL_NAME = os.environ.get("MODEL_NAME", "") # kept for consistent runner env +SIGNING_ALGO = os.environ.get("SIGNING_ALGO", "ecdsa").lower() + + +def main(): + if not BASE_URL or not API_KEY: + raise RuntimeError("Please set BASE_URL and API_KEY") + + url = f"{BASE_URL}/v1/attestation/report" + resp = requests.get( + url, + params={"signing_algo": SIGNING_ALGO}, + headers={"Authorization": f"Bearer {API_KEY}"}, + timeout=60, + ) + + print("status:", resp.status_code) + resp.raise_for_status() + data = resp.json() + + assert "signing_public_key" in data, "missing signing_public_key" + assert "all_attestations" in data and isinstance(data["all_attestations"], list), "missing all_attestations" + assert data["all_attestations"], "all_attestations is empty" + assert ( + data["all_attestations"][0].get("signing_public_key") == data["signing_public_key"] + ), "top-level signing_public_key mismatch" + + expected_len = 128 if SIGNING_ALGO == "ecdsa" else 64 + assert len(data["signing_public_key"]) == expected_len, ( + f"unexpected signing_public_key length: {len(data['signing_public_key'])}, " + f"expected {expected_len} for {SIGNING_ALGO}" + ) + + print("[OK] /v1/attestation/report validated") + print("signing_algo:", SIGNING_ALGO) + print("signing_public_key len:", 
len(data["signing_public_key"])) + if MODEL_NAME: + print("model_name (unused in this check):", MODEL_NAME) + + +if __name__ == "__main__": + main()