From 8466f6130c5ae9ddb48bde47e5e0651e2c544d2f Mon Sep 17 00:00:00 2001 From: Priyank <15610225+priyankinfinnov@users.noreply.github.com> Date: Wed, 12 Nov 2025 16:28:06 +0530 Subject: [PATCH 1/7] tracer working --- .../tracing/portkey/portkey_tracing.ipynb | 152 ++++ src/openlayer/lib/__init__.py | 19 + .../lib/integrations/portkey_tracer.py | 720 ++++++++++++++++++ 3 files changed, 891 insertions(+) create mode 100644 examples/tracing/portkey/portkey_tracing.ipynb create mode 100644 src/openlayer/lib/integrations/portkey_tracer.py diff --git a/examples/tracing/portkey/portkey_tracing.ipynb b/examples/tracing/portkey/portkey_tracing.ipynb new file mode 100644 index 00000000..4231b462 --- /dev/null +++ b/examples/tracing/portkey/portkey_tracing.ipynb @@ -0,0 +1,152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openlayer-ai/openlayer-python/blob/main/examples/tracing/portkey/portkey_tracing.ipynb)\n", + "\n", + "\n", + "# Portkey monitoring quickstart\n", + "\n", + "This notebook illustrates how to get started monitoring Portkey completions with Openlayer.\n", + "\n", + "Portkey provides a unified interface to call 100+ LLM APIs using the same input/output format. This integration allows you to trace and monitor completions across all supported providers through a single interface.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install openlayer portkey-ai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Set the environment variables\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from portkey_ai import Portkey\n", + "\n", + "\n", + "# Set your Portkey API keys\n", + "os.environ['PORTKEY_API_KEY'] = \"YOUR_PORTKEY_API_HERE\"\n", + "\n", + "# Openlayer env variables\n", + "os.environ[\"OPENLAYER_API_KEY\"] = \"YOUR_OPENLAYER_API_KEY_HERE\"\n", + "os.environ[\"OPENLAYER_INFERENCE_PIPELINE_ID\"] = \"YOUR_OPENLAYER_INFERENCE_PIPELINE_ID_HERE\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Enable Portkey tracing\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from openlayer.lib import trace_portkey\n", + "\n", + "# Enable openlayer tracing for all Portkey completions\n", + "trace_portkey()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Use Portkey normally - tracing happens automatically!\n",
+    "\n",
+    "### Basic completion with OpenAI\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Basic Portkey client initialization\n",
+    "portkey = Portkey(\n",
+    "    api_key = os.environ['PORTKEY_API_KEY'],\n",
+    "    config = \"YOUR_PORTKEY_CONFIG_ID_HERE\",  # optional: your Portkey config ID\n",
+    ")\n",
+    "\n",
+    "# Basic Portkey LLM call\n",
+    "response = portkey.chat.completions.create(\n",
+    "    # model = \"@YOUR_PORTKEY_SLUG/YOUR_MODEL_NAME\",  # optional when a config is provided\n",
+    "    messages = [\n",
+    "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
+    "        {\"role\": \"user\", \"content\": \"Write a poem about Argentina, at least 500 words.\"}\n",
+    "    ]\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. View your traces\n",
+    "\n",
+    "Once you've run the example above, you can:\n",
+    "\n",
+    "1. **Visit your Openlayer dashboard** to see all the traced completions\n",
+    "2. **Analyze performance** across different models and providers\n",
+    "3. **Monitor costs** and token usage\n",
+    "4. **Debug issues** with detailed request/response logs\n",
+    "5. **Compare models** side-by-side\n",
+    "\n",
+    "The traces will include:\n",
+    "- **Request details**: Model, parameters, messages\n",
+    "- **Response data**: Generated content, token counts, latency\n",
+    "- **Provider information**: Which underlying service was used\n",
+    "- **Custom metadata**: Any additional context you provide\n",
+    "\n",
+    "For more information, check out:\n",
+    "- [Openlayer Documentation](https://docs.openlayer.com/)\n",
+    "- [Portkey Documentation](https://portkey.ai/docs)\n",
+    "- [Portkey AI Gateway](https://portkey.ai/docs/product/ai-gateway)\n",
+    "- [Portkey Supported Providers](https://portkey.ai/docs/api-reference/inference-api/supported-providers)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/openlayer/lib/__init__.py b/src/openlayer/lib/__init__.py
index 618f7a5e..447afd7d 100644
--- a/src/openlayer/lib/__init__.py
+++ b/src/openlayer/lib/__init__.py
@@ -14,6 +14,7 @@
     "trace_oci_genai",
     "trace_oci",  # Alias for backward compatibility
     "trace_litellm",
+    "trace_portkey",
     "update_current_trace",
     "update_current_step",
     # Offline buffer management functions
@@ -188,3 +189,21 @@ def trace_litellm():
     from .integrations import litellm_tracer
 
     return litellm_tracer.trace_litellm()
+
+
+# ---------------------------------- Portkey ---------------------------------- #
+def trace_portkey():
+    """Enable tracing for Portkey completions.
+
+    This function patches Portkey's chat.completions.create to automatically trace
+    all OpenAI-compatible completions routed via the Portkey AI Gateway.
+    """
+    # pylint: disable=import-outside-toplevel
+    try:
+        from portkey_ai import Portkey  # noqa: F401
+    except ImportError:
+        raise ImportError("portkey-ai is required for Portkey tracing. 
Install with: pip install portkey-ai")
+
+    from .integrations import portkey_tracer
+
+    return portkey_tracer.trace_portkey()
diff --git a/src/openlayer/lib/integrations/portkey_tracer.py b/src/openlayer/lib/integrations/portkey_tracer.py
new file mode 100644
index 00000000..20eb75b6
--- /dev/null
+++ b/src/openlayer/lib/integrations/portkey_tracer.py
@@ -0,0 +1,720 @@
+"""Module with methods used to trace Portkey AI Gateway chat completions."""
+
+import json
+import logging
+import time
+from functools import wraps
+from typing import Any, Dict, Iterator, Optional, Union
+
+try:
+    from portkey_ai import Portkey
+    HAVE_PORTKEY = True
+except ImportError:
+    HAVE_PORTKEY = False
+
+from ..tracing import tracer
+
+logger = logging.getLogger(__name__)
+
+
+def trace_portkey() -> None:
+    """Patch Portkey's chat.completions.create to trace completions.
+
+    The following information is collected for each completion:
+    - start_time: The time when the completion was requested.
+    - end_time: The time when the completion was received.
+    - latency: The time it took to generate the completion.
+    - tokens: The total number of tokens used to generate the completion.
+    - prompt_tokens: The number of tokens in the prompt.
+    - completion_tokens: The number of tokens in the completion.
+    - model: The model used to generate the completion.
+    - model_parameters: The parameters used to configure the model.
+    - raw_output: The raw output of the model.
+    - inputs: The inputs used to generate the completion.
+    - Portkey-specific metadata (base URL, x-portkey-* headers if available)
+
+    Returns
+    -------
+    None
+        This function patches portkey.chat.completions.create in place.
+
+    Example
+    -------
+    >>> from portkey_ai import Portkey
+    >>> from openlayer.lib import trace_portkey
+    >>>
+    >>> trace_portkey()
+    >>>
+    >>> # Use Portkey normally - tracing happens automatically
+    >>> portkey = Portkey(api_key = "YOUR_PORTKEY_API_KEY")
+    >>> response = portkey.chat.completions.create(
+    ...     model = "@YOUR_PROVIDER_SLUG/MODEL_NAME",
+    ...     messages = [
+    ...         {"role": "system", "content": "You are a helpful assistant."},
+    ...         {"role": "user", "content": "What is Portkey?"}
+    ...     ],
+    ...     inference_id="custom-id-123",  # Optional Openlayer parameter
+    ...     max_tokens = 512
+    ... )
+    """
+    if not HAVE_PORTKEY:
+        raise ImportError(
+            "Portkey library is not installed. Please install it with: pip install portkey-ai"
+        )
+
+    # Patch instances on initialization rather than class-level attributes.
+    # Some SDKs initialize 'chat' lazily on the instance.
+    original_init = Portkey.__init__
+
+    @wraps(original_init)
+    def traced_init(self, *args, **kwargs):
+        original_init(self, *args, **kwargs)
+        try:
+            # Avoid double-patching
+            if getattr(self, "_openlayer_portkey_patched", False):
+                return
+            # Access chat to ensure it's constructed, then wrap create
+            chat = getattr(self, "chat", None)
+            if chat is None or not hasattr(chat, "completions") or not hasattr(chat.completions, "create"):
+                # If the structure isn't present, skip gracefully and log diagnostics
+                logger.debug(
+                    "Openlayer Portkey tracer: Portkey client missing expected attributes (chat/completions/create). "
+                    "Tracing not applied for this instance."
+ ) + return + original_create = chat.completions.create + + @wraps(original_create) + def traced_create(*c_args, **c_kwargs): + inference_id = c_kwargs.pop("inference_id", None) + stream = c_kwargs.get("stream", False) + if stream: + return handle_streaming_create( + self, + *c_args, + **c_kwargs, + create_func=original_create, + inference_id=inference_id, + ) + return handle_non_streaming_create( + self, + *c_args, + **c_kwargs, + create_func=original_create, + inference_id=inference_id, + ) + + self.chat.completions.create = traced_create + setattr(self, "_openlayer_portkey_patched", True) + logger.debug("Openlayer Portkey tracer: successfully patched Portkey client instance for tracing.") + except Exception as e: + logger.debug("Failed to patch Portkey client instance for tracing: %s", e) + + Portkey.__init__ = traced_init + logger.info("Openlayer Portkey tracer: tracing enabled (instance-level patch).") + + +def handle_streaming_create( + client: "Portkey", + *args, + create_func: callable, + inference_id: Optional[str] = None, + **kwargs, +) -> Iterator[Any]: + """Handles streaming chat.completions.create routed via Portkey.""" + # Portkey is OpenAI-compatible; request and chunks follow OpenAI spec + # create_func is a bound method; do not pass client again + chunks = create_func(*args, **kwargs) + return stream_chunks( + chunks=chunks, + kwargs=kwargs, + client=client, + inference_id=inference_id, + ) + + +def stream_chunks( + chunks: Iterator[Any], + kwargs: Dict[str, Any], + client: "Portkey", + inference_id: Optional[str] = None, +): + """Streams the chunks of the completion and traces the completion.""" + collected_output_data = [] + collected_function_call = {"name": "", "arguments": ""} + raw_outputs = [] + start_time = time.time() + end_time = None + first_token_time = None + num_of_completion_tokens = None + latency = None + model_name = kwargs.get("model", "unknown") + provider = "unknown" + latest_usage_data = {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None} + latest_chunk_metadata: Dict[str, Any] = {} + + try: + i = 0 + for i, chunk in enumerate(chunks): + raw_outputs.append(chunk.model_dump() if hasattr(chunk, "model_dump") else str(chunk)) + + if i == 0: + first_token_time = time.time() + # Try to detect provider at first chunk + provider = detect_provider_from_chunk(chunk, client, model_name) + if i > 0: + num_of_completion_tokens = i + 1 + + # Extract usage from chunk if available + chunk_usage = extract_usage_from_chunk(chunk) + if any(v is not None for v in chunk_usage.values()): + latest_usage_data = chunk_usage + + # Update metadata from latest chunk (headers/cost/etc.) 
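+            # Later chunks overwrite earlier values, so the final trace keeps the most recent gateway hints (e.g., provider and cost).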
+ chunk_metadata = extract_portkey_unit_metadata(chunk, model_name) + if chunk_metadata: + latest_chunk_metadata.update(chunk_metadata) + + # Extract delta from chunk (OpenAI-compatible) + delta = get_delta_from_chunk(chunk) + + if delta and getattr(delta, "content", None): + collected_output_data.append(delta.content) + elif delta and getattr(delta, "function_call", None): + if delta.function_call.name: + collected_function_call["name"] += delta.function_call.name + if delta.function_call.arguments: + collected_function_call["arguments"] += delta.function_call.arguments + elif delta and getattr(delta, "tool_calls", None): + tool_call = delta.tool_calls[0] + if getattr(tool_call.function, "name", None): + collected_function_call["name"] += tool_call.function.name + if getattr(tool_call.function, "arguments", None): + collected_function_call["arguments"] += tool_call.function.arguments + + yield chunk + end_time = time.time() + latency = (end_time - start_time) * 1000 + # pylint: disable=broad-except + except Exception as e: + logger.error("Failed to yield Portkey chunk. %s", e) + finally: + # Try to add step to the trace + try: + collected_output_data = [m for m in collected_output_data if m is not None] + if collected_output_data: + output_data = "".join(collected_output_data) + else: + if collected_function_call["arguments"]: + try: + collected_function_call["arguments"] = json.loads(collected_function_call["arguments"]) + except json.JSONDecodeError: + pass + output_data = collected_function_call + + # Calculate usage and cost at end of stream (prioritize actual usage if present) + completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost( + chunks=raw_outputs, + messages=kwargs.get("messages", []), + output_content=output_data, + model_name=model_name, + latest_usage_data=latest_usage_data, + latest_chunk_metadata=latest_chunk_metadata, + ) + + usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {} + final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0) + final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens) + final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", (final_prompt_tokens or 0) + (final_completion_tokens or 0)) + final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get("cost", None) + + trace_args = create_trace_args( + end_time=end_time, + inputs={"prompt": kwargs.get("messages", [])}, + output=output_data, + latency=latency, + tokens=final_total_tokens, + prompt_tokens=final_prompt_tokens, + completion_tokens=final_completion_tokens, + model=model_name, + model_parameters=get_model_parameters(kwargs), + raw_output=raw_outputs, + id=inference_id, + cost=final_cost, + metadata={ + "timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None), + "provider": provider, + "portkey_model": model_name, + **extract_portkey_metadata(client), + **latest_chunk_metadata, + }, + ) + add_to_trace(**trace_args) + # pylint: disable=broad-except + except Exception as e: + logger.error("Failed to trace the Portkey streaming completion. 
%s", e) + + +def handle_non_streaming_create( + client: "Portkey", + *args, + create_func: callable, + inference_id: Optional[str] = None, + **kwargs, +) -> Any: + """Handles non-streaming chat.completions.create routed via Portkey.""" + start_time = time.time() + # create_func is a bound method; do not pass client again + response = create_func(*args, **kwargs) + end_time = time.time() + + # Try to add step to the trace + try: + output_data = parse_non_streaming_output_data(response) + + # Usage (if provided by upstream provider via Portkey) + usage_data = extract_usage_from_response(response) + model_name = getattr(response, "model", kwargs.get("model", "unknown")) + provider = detect_provider_from_response(response, client, model_name) + extra_metadata = extract_portkey_unit_metadata(response, model_name) + cost = extra_metadata.get("cost", None) + + trace_args = create_trace_args( + end_time=end_time, + inputs={"prompt": kwargs.get("messages", [])}, + output=output_data, + latency=(end_time - start_time) * 1000, + tokens=usage_data.get("total_tokens"), + prompt_tokens=usage_data.get("prompt_tokens"), + completion_tokens=usage_data.get("completion_tokens"), + model=model_name, + model_parameters=get_model_parameters(kwargs), + raw_output=response.model_dump() if hasattr(response, "model_dump") else str(response), + id=inference_id, + cost=cost, + metadata={ + "system_fingerprint": getattr(response, "system_fingerprint", None), + "provider": provider, + "portkey_model": model_name, + **extract_portkey_metadata(client), + **extra_metadata, + }, + ) + add_to_trace(**trace_args) + # pylint: disable=broad-except + except Exception as e: + logger.error("Failed to trace the Portkey non-streaming completion. %s", e) + + return response + + +def get_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Gets the model parameters from the kwargs (OpenAI-compatible).""" + return { + "temperature": kwargs.get("temperature", 1), + "top_p": kwargs.get("top_p", 1), + "max_tokens": kwargs.get("max_tokens", None), + "n": kwargs.get("n", 1), + "stream": kwargs.get("stream", False), + "stop": kwargs.get("stop", None), + "presence_penalty": kwargs.get("presence_penalty", 0), + "frequency_penalty": kwargs.get("frequency_penalty", 0), + "logit_bias": kwargs.get("logit_bias", None), + "logprobs": kwargs.get("logprobs", False), + "top_logprobs": kwargs.get("top_logprobs", None), + "parallel_tool_calls": kwargs.get("parallel_tool_calls", True), + "seed": kwargs.get("seed", None), + "response_format": kwargs.get("response_format", None), + "timeout": kwargs.get("timeout", None), + "api_base": kwargs.get("api_base", None), + "api_version": kwargs.get("api_version", None), + } + + +def create_trace_args( + end_time: float, + inputs: Dict[str, Any], + output: Union[str, Dict[str, Any], None], + latency: float, + tokens: Optional[int], + prompt_tokens: Optional[int], + completion_tokens: Optional[int], + model: str, + model_parameters: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + raw_output: Optional[Union[str, Dict[str, Any]]] = None, + id: Optional[str] = None, + cost: Optional[float] = None, +) -> Dict[str, Any]: + """Returns a dictionary with the trace arguments.""" + trace_args = { + "end_time": end_time, + "inputs": inputs, + "output": output, + "latency": latency, + "tokens": tokens, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "model": model, + "model_parameters": model_parameters, + "raw_output": raw_output, + "metadata": metadata 
if metadata else {}, + } + if id: + trace_args["id"] = id + if cost is not None: + trace_args["cost"] = cost + return trace_args + + +def add_to_trace(**kwargs) -> None: + """Add a chat completion step to the trace.""" + provider = kwargs.get("metadata", {}).get("provider", "Portkey") + tracer.add_chat_completion_step_to_trace(**kwargs, name="Portkey Chat Completion", provider=provider) + + +def parse_non_streaming_output_data(response: Any) -> Union[str, Dict[str, Any], None]: + """Parses the output data from a non-streaming completion (OpenAI-compatible).""" + try: + if hasattr(response, "choices") and response.choices: + choice = response.choices[0] + message = getattr(choice, "message", None) + if message is None: + return None + content = getattr(message, "content", None) + function_call = getattr(message, "function_call", None) + tool_calls = getattr(message, "tool_calls", None) + if content: + return content.strip() + if function_call: + return { + "name": function_call.name, + "arguments": json.loads(function_call.arguments) if isinstance(function_call.arguments, str) else function_call.arguments, + } + if tool_calls: + return { + "name": tool_calls[0].function.name, + "arguments": json.loads(tool_calls[0].function.arguments) if isinstance(tool_calls[0].function.arguments, str) else tool_calls[0].function.arguments, + } + except Exception as e: + logger.debug("Error parsing Portkey output data: %s", e) + return None + + +def extract_portkey_metadata(client: "Portkey") -> Dict[str, Any]: + """Extract Portkey-specific metadata from a Portkey client. + + Attempts to read base URL and redact x-portkey-* headers if present. + Works defensively across SDK versions. + """ + metadata: Dict[str, Any] = {"isPortkey": True} + # Base URL or host + for attr in ("base_url", "baseURL", "host", "custom_host"): + try: + val = getattr(client, attr, None) + if val: + metadata["portkeyBaseUrl"] = str(val) + break + except Exception: + continue + + # Headers + possible_header_attrs = ("default_headers", "headers", "_default_headers", "_headers") + redacted: Dict[str, Any] = {} + for attr in possible_header_attrs: + try: + headers = getattr(client, attr, None) + if isinstance(headers, dict): + for k, v in headers.items(): + if isinstance(k, str) and k.lower().startswith("x-portkey-"): + if k.lower() in {"x-portkey-api-key", "x-portkey-virtual-key"}: + redacted[k] = "***" + else: + redacted[k] = v + except Exception: + continue + if redacted: + metadata["portkeyHeaders"] = redacted + else: + logger.debug( + "Openlayer Portkey tracer: No x-portkey-* headers detected on client; provider/config metadata may be limited." 
+ ) + return metadata + + +def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]: + """Extract metadata from a response or chunk unit (headers, cost, ids).""" + metadata: Dict[str, Any] = {} + try: + # Extract system fingerprint if available (OpenAI-compatible) + if hasattr(unit, "system_fingerprint"): + metadata["system_fingerprint"] = unit.system_fingerprint + + # Response headers may be present on the object + headers_obj = None + if hasattr(unit, "_response_headers"): + headers_obj = getattr(unit, "_response_headers") + elif hasattr(unit, "response_headers"): + headers_obj = getattr(unit, "response_headers") + if isinstance(headers_obj, dict): + headers = {str(k): v for k, v in headers_obj.items()} + metadata["response_headers"] = headers + # Known Portkey header hints (names are lower-cased defensively) + lower = {k.lower(): v for k, v in headers.items()} + if "x-portkey-request-id" in lower: + metadata["portkey_request_id"] = lower["x-portkey-request-id"] + if "x-portkey-route" in lower: + metadata["portkey_route"] = lower["x-portkey-route"] + if "x-portkey-config" in lower: + metadata["portkey_config"] = lower["x-portkey-config"] + if "x-portkey-provider" in lower: + metadata["provider"] = lower["x-portkey-provider"] + # Cost if surfaced by gateway + if "x-portkey-cost" in lower: + try: + metadata["cost"] = float(lower["x-portkey-cost"]) + except Exception: + pass + except Exception: + pass + # Attach model for convenience + if model_name: + metadata["portkey_model"] = model_name + return metadata + + +def extract_usage_from_response(response: Any) -> Dict[str, Optional[int]]: + """Extract usage from a non-streaming response.""" + try: + if hasattr(response, "usage") and response.usage is not None: + usage = response.usage + return { + "total_tokens": getattr(usage, "total_tokens", None), + "prompt_tokens": getattr(usage, "prompt_tokens", None), + "completion_tokens": getattr(usage, "completion_tokens", None), + } + except Exception: + pass + return {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None} + + +def extract_usage_from_chunk(chunk: Any) -> Dict[str, Optional[int]]: + """Extract usage from a streaming chunk if present.""" + try: + # Check for usage attribute + if hasattr(chunk, "usage") and chunk.usage is not None: + usage = chunk.usage + return { + "total_tokens": getattr(usage, "total_tokens", None), + "prompt_tokens": getattr(usage, "prompt_tokens", None), + "completion_tokens": getattr(usage, "completion_tokens", None), + } + + # Check for usage in _hidden_params (if SDK stores it there) + if hasattr(chunk, "_hidden_params"): + hidden_params = chunk._hidden_params + # Check if usage is a direct attribute + if hasattr(hidden_params, "usage") and hidden_params.usage is not None: + usage = hidden_params.usage + return { + "total_tokens": getattr(usage, "total_tokens", None), + "prompt_tokens": getattr(usage, "prompt_tokens", None), + "completion_tokens": getattr(usage, "completion_tokens", None), + } + # Check if usage is a dictionary key + elif isinstance(hidden_params, dict) and "usage" in hidden_params: + usage = hidden_params["usage"] + if usage: + return { + "total_tokens": usage.get("total_tokens", None), + "prompt_tokens": usage.get("prompt_tokens", None), + "completion_tokens": usage.get("completion_tokens", None), + } + + # Check if chunk model dump has usage + if hasattr(chunk, "model_dump"): + chunk_dict = chunk.model_dump() + if isinstance(chunk_dict, dict) and "usage" in chunk_dict and chunk_dict["usage"]: + usage 
= chunk_dict["usage"] + return { + "total_tokens": usage.get("total_tokens", None), + "prompt_tokens": usage.get("prompt_tokens", None), + "completion_tokens": usage.get("completion_tokens", None), + } + except Exception: + pass + return {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None} + + +def calculate_streaming_usage_and_cost( + chunks: Any, + messages: Any, + output_content: Any, + model_name: str, + latest_usage_data: Dict[str, Optional[int]], + latest_chunk_metadata: Dict[str, Any], +): + """Calculate usage and cost at the end of streaming (LiteLLM-style logic).""" + try: + # Priority 1: Actual usage provided in chunks + if latest_usage_data and latest_usage_data.get("total_tokens") and latest_usage_data.get("total_tokens") > 0: + return ( + latest_usage_data.get("completion_tokens"), + latest_usage_data.get("prompt_tokens"), + latest_usage_data.get("total_tokens"), + latest_chunk_metadata.get("cost"), + ) + + # Priority 2: Look for usage embedded in final chunk dicts (if raw dicts) + if isinstance(chunks, list): + for chunk_data in reversed(chunks): + if isinstance(chunk_data, dict) and "usage" in chunk_data and chunk_data["usage"]: + usage = chunk_data["usage"] + if usage.get("total_tokens", 0) > 0: + return ( + usage.get("completion_tokens"), + usage.get("prompt_tokens"), + usage.get("total_tokens"), + latest_chunk_metadata.get("cost"), + ) + + # Priority 3: Estimate tokens + completion_tokens = None + prompt_tokens = None + total_tokens = None + cost = None + + # Estimate completion tokens + if isinstance(output_content, str): + completion_tokens = max(1, len(output_content) // 4) + elif isinstance(output_content, dict): + json_str = json.dumps(output_content) if output_content else "{}" + completion_tokens = max(1, len(json_str) // 4) + else: + # Fallback: count chunks present + try: + completion_tokens = len([c for c in chunks if c]) + except Exception: + completion_tokens = None + + # Estimate prompt tokens from messages + if messages: + total_chars = 0 + try: + for message in messages: + if isinstance(message, dict) and "content" in message: + total_chars += len(str(message["content"])) + except Exception: + total_chars = 0 + prompt_tokens = max(1, total_chars // 4) if total_chars > 0 else 0 + else: + prompt_tokens = 0 + + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + + # Cost from headers if present; otherwise simple heuristic for some models + cost = latest_chunk_metadata.get("cost") + if cost is None and total_tokens and model_name: + ml = model_name.lower() + if "gpt-3.5-turbo" in ml: + cost = (prompt_tokens * 0.0005 / 1000.0) + (completion_tokens * 0.0015 / 1000.0) + elif "gpt-4" in ml: + # very rough heuristic, not guaranteed + cost = (prompt_tokens * 0.03 / 1_000_000.0) + (completion_tokens * 0.06 / 1_000_000.0) + + return completion_tokens, prompt_tokens, total_tokens, cost + except Exception: + return None, None, None, None + + +def detect_provider_from_response(response: Any, client: "Portkey", model_name: str) -> str: + """Detect provider for non-streaming responses.""" + # First: check Portkey headers on the client (authoritative) + provider = _provider_from_portkey_headers(client) + if provider: + return provider + # Next: check response metadata if any + try: + # Some SDKs attach response headers/metadata + if hasattr(response, "response_metadata") and isinstance(response.response_metadata, dict): + if "provider" in response.response_metadata: + return response.response_metadata["provider"] + if hasattr(response, 
"_response_headers"): + headers = getattr(response, "_response_headers") + if isinstance(headers, dict): + for k, v in headers.items(): + if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: + return str(v) + except Exception: + pass + # Fallback to model name heuristics + return detect_provider_from_model_name(model_name) + + +def detect_provider_from_chunk(chunk: Any, client: "Portkey", model_name: str) -> str: + """Detect provider for streaming chunks.""" + # First: check Portkey headers on the client + provider = _provider_from_portkey_headers(client) + if provider: + return provider + # Next: see if chunk exposes any metadata + try: + if hasattr(chunk, "response_metadata") and isinstance(chunk.response_metadata, dict): + if "provider" in chunk.response_metadata: + return chunk.response_metadata["provider"] + if hasattr(chunk, "_response_headers"): + headers = getattr(chunk, "_response_headers") + if isinstance(headers, dict): + for k, v in headers.items(): + if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: + return str(v) + except Exception: + pass + # Fallback to model name heuristics + return detect_provider_from_model_name(model_name) + + +def detect_provider_from_model_name(model_name: str) -> str: + """Detect provider from model name patterns (similar to LiteLLM tracer).""" + model_lower = (model_name or "").lower() + if model_lower.startswith(("gpt-", "o1-", "text-davinci", "text-curie", "text-babbage", "text-ada")): + return "OpenAI" + if model_lower.startswith(("claude-", "claude")): + return "Anthropic" + if "gemini" in model_lower or "palm" in model_lower: + return "Google" + if "llama" in model_lower or "meta-" in model_lower: + return "Meta" + if model_lower.startswith("mistral") or "mixtral" in model_lower: + return "Mistral" + if model_lower.startswith("command"): + return "Cohere" + return "Portkey" + + +def get_delta_from_chunk(chunk: Any) -> Any: + """Extract delta from chunk, handling different response formats.""" + try: + if hasattr(chunk, "choices") and chunk.choices: + choice = chunk.choices[0] + if hasattr(choice, "delta"): + return choice.delta + except Exception: + pass + return None + + +def _provider_from_portkey_headers(client: "Portkey") -> Optional[str]: + """Get provider from Portkey headers on the client.""" + header_sources = ("default_headers", "headers", "_default_headers", "_headers") + for attr in header_sources: + try: + headers = getattr(client, attr, None) + if isinstance(headers, dict): + for k, v in headers.items(): + if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: + return str(v) + except Exception: + continue + return None + From fd534506fbe375475d4ead1545410c4df8854241 Mon Sep 17 00:00:00 2001 From: Priyank <15610225+priyankinfinnov@users.noreply.github.com> Date: Thu, 13 Nov 2025 09:41:18 +0530 Subject: [PATCH 2/7] fixes and added tests --- .../lib/integrations/portkey_tracer.py | 72 ++- tests/test_integration_conditional_imports.py | 1 + tests/test_portkey_integration.py | 585 ++++++++++++++++++ 3 files changed, 632 insertions(+), 26 deletions(-) create mode 100644 tests/test_portkey_integration.py diff --git a/src/openlayer/lib/integrations/portkey_tracer.py b/src/openlayer/lib/integrations/portkey_tracer.py index 20eb75b6..4313ecef 100644 --- a/src/openlayer/lib/integrations/portkey_tracer.py +++ b/src/openlayer/lib/integrations/portkey_tracer.py @@ -4,7 +4,7 @@ import logging import time from functools import wraps -from typing import Any, Dict, Iterator, Optional, Union +from typing 
import Any, Dict, Iterator, Optional, Union, TYPE_CHECKING try: from portkey_ai import Portkey @@ -12,6 +12,9 @@ except ImportError: HAVE_PORTKEY = False +if TYPE_CHECKING: + from portkey_ai import Portkey + from ..tracing import tracer logger = logging.getLogger(__name__) @@ -92,16 +95,16 @@ def traced_create(*c_args, **c_kwargs): return handle_streaming_create( self, *c_args, - **c_kwargs, create_func=original_create, inference_id=inference_id, + **c_kwargs, ) return handle_non_streaming_create( self, *c_args, - **c_kwargs, create_func=original_create, inference_id=inference_id, + **c_kwargs, ) self.chat.completions.create = traced_create @@ -419,12 +422,12 @@ def extract_portkey_metadata(client: "Portkey") -> Dict[str, Any]: continue # Headers - possible_header_attrs = ("default_headers", "headers", "_default_headers", "_headers") + possible_header_attrs = ("default_headers", "headers", "_default_headers", "_headers", "custom_headers", "allHeaders") redacted: Dict[str, Any] = {} for attr in possible_header_attrs: try: headers = getattr(client, attr, None) - if isinstance(headers, dict): + if _is_dict_like(headers): for k, v in headers.items(): if isinstance(k, str) and k.lower().startswith("x-portkey-"): if k.lower() in {"x-portkey-api-key", "x-portkey-virtual-key"}: @@ -449,6 +452,8 @@ def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]: # Extract system fingerprint if available (OpenAI-compatible) if hasattr(unit, "system_fingerprint"): metadata["system_fingerprint"] = unit.system_fingerprint + if hasattr(unit, "service_tier"): + metadata["service_tier"] = unit.service_tier # Response headers may be present on the object headers_obj = None @@ -456,19 +461,20 @@ def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]: headers_obj = getattr(unit, "_response_headers") elif hasattr(unit, "response_headers"): headers_obj = getattr(unit, "response_headers") - if isinstance(headers_obj, dict): + elif hasattr(unit, "_headers"): + headers_obj = getattr(unit, "_headers") + + if _is_dict_like(headers_obj): headers = {str(k): v for k, v in headers_obj.items()} metadata["response_headers"] = headers # Known Portkey header hints (names are lower-cased defensively) lower = {k.lower(): v for k, v in headers.items()} - if "x-portkey-request-id" in lower: - metadata["portkey_request_id"] = lower["x-portkey-request-id"] - if "x-portkey-route" in lower: - metadata["portkey_route"] = lower["x-portkey-route"] - if "x-portkey-config" in lower: - metadata["portkey_config"] = lower["x-portkey-config"] + if "x-portkey-trace-id" in lower: + metadata["portkey_trace_id"] = lower["x-portkey-trace-id"] if "x-portkey-provider" in lower: metadata["provider"] = lower["x-portkey-provider"] + if "x-portkey-cache-status" in lower: + metadata["portkey_cache_status"] = lower["x-portkey-cache-status"] # Cost if surfaced by gateway if "x-portkey-cost" in lower: try: @@ -522,7 +528,7 @@ def extract_usage_from_chunk(chunk: Any) -> Dict[str, Optional[int]]: "completion_tokens": getattr(usage, "completion_tokens", None), } # Check if usage is a dictionary key - elif isinstance(hidden_params, dict) and "usage" in hidden_params: + elif _supports_membership_check(hidden_params) and "usage" in hidden_params: usage = hidden_params["usage"] if usage: return { @@ -534,7 +540,7 @@ def extract_usage_from_chunk(chunk: Any) -> Dict[str, Optional[int]]: # Check if chunk model dump has usage if hasattr(chunk, "model_dump"): chunk_dict = chunk.model_dump() - if isinstance(chunk_dict, dict) 
and "usage" in chunk_dict and chunk_dict["usage"]: + if _supports_membership_check(chunk_dict) and "usage" in chunk_dict and chunk_dict["usage"]: usage = chunk_dict["usage"] return { "total_tokens": usage.get("total_tokens", None), @@ -568,7 +574,7 @@ def calculate_streaming_usage_and_cost( # Priority 2: Look for usage embedded in final chunk dicts (if raw dicts) if isinstance(chunks, list): for chunk_data in reversed(chunks): - if isinstance(chunk_data, dict) and "usage" in chunk_data and chunk_data["usage"]: + if _supports_membership_check(chunk_data) and "usage" in chunk_data and chunk_data["usage"]: usage = chunk_data["usage"] if usage.get("total_tokens", 0) > 0: return ( @@ -587,7 +593,7 @@ def calculate_streaming_usage_and_cost( # Estimate completion tokens if isinstance(output_content, str): completion_tokens = max(1, len(output_content) // 4) - elif isinstance(output_content, dict): + elif _is_dict_like(output_content): json_str = json.dumps(output_content) if output_content else "{}" completion_tokens = max(1, len(json_str) // 4) else: @@ -602,7 +608,7 @@ def calculate_streaming_usage_and_cost( total_chars = 0 try: for message in messages: - if isinstance(message, dict) and "content" in message: + if _supports_membership_check(message) and "content" in message: total_chars += len(str(message["content"])) except Exception: total_chars = 0 @@ -618,12 +624,10 @@ def calculate_streaming_usage_and_cost( ml = model_name.lower() if "gpt-3.5-turbo" in ml: cost = (prompt_tokens * 0.0005 / 1000.0) + (completion_tokens * 0.0015 / 1000.0) - elif "gpt-4" in ml: - # very rough heuristic, not guaranteed - cost = (prompt_tokens * 0.03 / 1_000_000.0) + (completion_tokens * 0.06 / 1_000_000.0) return completion_tokens, prompt_tokens, total_tokens, cost - except Exception: + except Exception as e: + logger.error("Error calculating streaming usage: %s", e) return None, None, None, None @@ -636,12 +640,12 @@ def detect_provider_from_response(response: Any, client: "Portkey", model_name: # Next: check response metadata if any try: # Some SDKs attach response headers/metadata - if hasattr(response, "response_metadata") and isinstance(response.response_metadata, dict): + if hasattr(response, "response_metadata") and _is_dict_like(response.response_metadata): if "provider" in response.response_metadata: return response.response_metadata["provider"] if hasattr(response, "_response_headers"): headers = getattr(response, "_response_headers") - if isinstance(headers, dict): + if _is_dict_like(headers): for k, v in headers.items(): if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: return str(v) @@ -659,12 +663,12 @@ def detect_provider_from_chunk(chunk: Any, client: "Portkey", model_name: str) - return provider # Next: see if chunk exposes any metadata try: - if hasattr(chunk, "response_metadata") and isinstance(chunk.response_metadata, dict): + if hasattr(chunk, "response_metadata") and _is_dict_like(chunk.response_metadata): if "provider" in chunk.response_metadata: return chunk.response_metadata["provider"] if hasattr(chunk, "_response_headers"): headers = getattr(chunk, "_response_headers") - if isinstance(headers, dict): + if _is_dict_like(headers): for k, v in headers.items(): if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: return str(v) @@ -710,7 +714,7 @@ def _provider_from_portkey_headers(client: "Portkey") -> Optional[str]: for attr in header_sources: try: headers = getattr(client, attr, None) - if isinstance(headers, dict): + if _is_dict_like(headers): for k, v in 
headers.items(): if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: return str(v) @@ -718,3 +722,19 @@ def _provider_from_portkey_headers(client: "Portkey") -> Optional[str]: continue return None + +def _is_dict_like(obj: Any) -> bool: + """Check if an object is dict-like (has .items() method). + + This is more robust than isinstance(obj, dict) as it handles + custom dict-like objects (e.g., CaseInsensitiveDict, custom headers). + """ + return hasattr(obj, "items") and callable(getattr(obj, "items", None)) + + +def _supports_membership_check(obj: Any) -> bool: + """Check if an object supports membership testing (e.g., 'key in obj'). + + This checks for __contains__ method or if it's dict-like. + """ + return hasattr(obj, "__contains__") or _is_dict_like(obj) diff --git a/tests/test_integration_conditional_imports.py b/tests/test_integration_conditional_imports.py index f673b480..adb5bf9d 100644 --- a/tests/test_integration_conditional_imports.py +++ b/tests/test_integration_conditional_imports.py @@ -34,6 +34,7 @@ "oci_tracer": ["oci"], "langchain_callback": ["langchain", "langchain_core", "langchain_community"], "litellm_tracer": ["litellm"], + "portkey_tracer": ["portkey_ai"], } # Expected patterns for integration modules diff --git a/tests/test_portkey_integration.py b/tests/test_portkey_integration.py new file mode 100644 index 00000000..410c6b39 --- /dev/null +++ b/tests/test_portkey_integration.py @@ -0,0 +1,585 @@ +"""Test Portkey tracer integration.""" + +import json +from types import SimpleNamespace +from typing import Any, Dict +from unittest.mock import Mock, patch + +import pytest # type: ignore + + +class TestPortkeyIntegration: + """Test Portkey integration functionality.""" + + def test_import_without_portkey(self) -> None: + """Module should import even when Portkey is unavailable.""" + from openlayer.lib.integrations import portkey_tracer # noqa: F401 + + assert hasattr(portkey_tracer, "HAVE_PORTKEY") + + def test_trace_portkey_raises_import_error_without_dependency(self) -> None: + """trace_portkey should raise ImportError when Portkey is missing.""" + with patch("openlayer.lib.integrations.portkey_tracer.HAVE_PORTKEY", False): + from openlayer.lib.integrations.portkey_tracer import trace_portkey + + with pytest.raises(ImportError) as exc_info: # type: ignore + trace_portkey() + + message = str(exc_info.value) # type: ignore[attr-defined] + assert "Portkey library is not installed" in message + assert "pip install portkey-ai" in message + + def test_trace_portkey_patches_portkey_client(self) -> None: + """trace_portkey should wrap Portkey chat completions for tracing.""" + + class DummyPortkey: # pylint: disable=too-few-public-methods + """Lightweight Portkey stand-in used for patching behavior.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401 + completions = SimpleNamespace(create=Mock(name="original_create")) + self.chat = SimpleNamespace(completions=completions) + self._init_args = (args, kwargs) + self.original_create = completions.create + + with patch("openlayer.lib.integrations.portkey_tracer.HAVE_PORTKEY", True), patch( + "openlayer.lib.integrations.portkey_tracer.Portkey", DummyPortkey, create=True + ), patch( + "openlayer.lib.integrations.portkey_tracer.handle_non_streaming_create", + autospec=True, + ) as mock_non_streaming, patch( + "openlayer.lib.integrations.portkey_tracer.handle_streaming_create", + autospec=True, + ) as mock_streaming: + mock_non_streaming.return_value = "non-stream-result" + 
mock_streaming.return_value = "stream-result" + + from openlayer.lib.integrations.portkey_tracer import trace_portkey + + trace_portkey() + + client = DummyPortkey() + # Non-streaming + result_non_stream = client.chat.completions.create(messages=[{"role": "user", "content": "hi"}]) + assert result_non_stream == "non-stream-result" + assert mock_non_streaming.call_count == 1 + non_stream_kwargs = mock_non_streaming.call_args.kwargs + assert non_stream_kwargs["create_func"] is client.original_create + assert non_stream_kwargs["inference_id"] is None + + # Streaming path + result_stream = client.chat.completions.create( + messages=[{"role": "user", "content": "hi"}], stream=True, inference_id="inference-123" + ) + assert result_stream == "stream-result" + assert mock_streaming.call_count == 1 + stream_kwargs = mock_streaming.call_args.kwargs + assert stream_kwargs["create_func"] is client.original_create + assert stream_kwargs["inference_id"] == "inference-123" + + def test_detect_provider_from_model_name(self) -> None: + """Provider detection should match model naming heuristics.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider_from_model_name + + test_cases = [ + ("gpt-4", "OpenAI"), + ("Gpt-3.5-turbo", "OpenAI"), + ("claude-3-opus", "Anthropic"), + ("gemini-pro", "Google"), + ("meta-llama-3-70b", "Meta"), + ("mixtral-8x7b", "Mistral"), + ("command-r", "Cohere"), + ("unknown-model", "Portkey"), + ] + + for model_name, expected in test_cases: + assert detect_provider_from_model_name(model_name) == expected + + def test_get_model_parameters(self) -> None: + """Ensure OpenAI-compatible kwargs are extracted.""" + from openlayer.lib.integrations.portkey_tracer import get_model_parameters + + kwargs = { + "temperature": 0.5, + "top_p": 0.7, + "max_tokens": 256, + "n": 3, + "stream": True, + "stop": ["END"], + "presence_penalty": 0.1, + "frequency_penalty": -0.1, + "logit_bias": {"1": -1}, + "logprobs": True, + "top_logprobs": 5, + "parallel_tool_calls": False, + "seed": 123, + "response_format": {"type": "json_object"}, + "timeout": 42, + "api_base": "https://api.example.com", + "api_version": "2024-05-01", + } + + params = get_model_parameters(kwargs) + + expected = kwargs.copy() + assert params == expected + + def test_extract_portkey_metadata(self) -> None: + """Portkey metadata should redact sensitive headers and include base URL.""" + from openlayer.lib.integrations.portkey_tracer import extract_portkey_metadata + + client = SimpleNamespace( + base_url="https://gateway.portkey.ai", + headers={ + "X-Portkey-Api-Key": "secret", + "X-Portkey-Provider": "openai", + "Authorization": "Bearer ignored", + }, + ) + + metadata = extract_portkey_metadata(client) + + assert metadata["isPortkey"] is True + assert metadata["portkeyBaseUrl"] == "https://gateway.portkey.ai" + assert metadata["portkeyHeaders"]["X-Portkey-Api-Key"] == "***" + assert metadata["portkeyHeaders"]["X-Portkey-Provider"] == "openai" + assert "Authorization" not in metadata["portkeyHeaders"] + + def test_extract_portkey_unit_metadata(self) -> None: + """Unit metadata should capture headers, cost, and provider hints.""" + from openlayer.lib.integrations.portkey_tracer import extract_portkey_unit_metadata + + unit = SimpleNamespace( + system_fingerprint="fingerprint-123", + _response_headers={ + "x-portkey-trace-id": "trace-1", + "x-portkey-provider": "anthropic", + "x-portkey-cache-status": "HIT", + "x-portkey-cost": "0.45", + "content-type": "application/json", + }, + ) + + metadata = 
extract_portkey_unit_metadata(unit, "claude-3-opus") + + assert metadata["system_fingerprint"] == "fingerprint-123" + assert metadata["portkey_trace_id"] == "trace-1" + assert metadata["provider"] == "anthropic" + assert metadata["portkey_cache_status"] == "HIT" + assert metadata["cost"] == pytest.approx(0.45) + assert metadata["portkey_model"] == "claude-3-opus" + assert metadata["response_headers"]["content-type"] == "application/json" + + def test_extract_portkey_unit_metadata_with_dict_like_headers(self) -> None: + """Unit metadata should work with dict-like objects (not just dicts).""" + from openlayer.lib.integrations.portkey_tracer import extract_portkey_unit_metadata + + # Create a dict-like object (has .items() but not isinstance(dict)) + class DictLikeHeaders: + def __init__(self): + self._data = { + "x-portkey-trace-id": "trace-2", + "x-portkey-provider": "openai", + "x-portkey-cache-status": "MISS", + } + + def items(self): + return self._data.items() + + unit = SimpleNamespace( + _response_headers=DictLikeHeaders(), + ) + + metadata = extract_portkey_unit_metadata(unit, "gpt-4") + + assert metadata["portkey_trace_id"] == "trace-2" + assert metadata["provider"] == "openai" + assert metadata["portkey_cache_status"] == "MISS" + + def test_extract_usage_from_response(self) -> None: + """Usage extraction should read OpenAI-style usage objects.""" + from openlayer.lib.integrations.portkey_tracer import extract_usage_from_response + + usage = SimpleNamespace(total_tokens=50, prompt_tokens=20, completion_tokens=30) + response = SimpleNamespace(usage=usage) + + assert extract_usage_from_response(response) == { + "total_tokens": 50, + "prompt_tokens": 20, + "completion_tokens": 30, + } + + response_no_usage = SimpleNamespace() + assert extract_usage_from_response(response_no_usage) == { + "total_tokens": None, + "prompt_tokens": None, + "completion_tokens": None, + } + + def test_extract_usage_from_chunk(self) -> None: + """Usage data should be derived from multiple potential chunk attributes.""" + from openlayer.lib.integrations.portkey_tracer import extract_usage_from_chunk + + chunk_direct = SimpleNamespace( + usage=SimpleNamespace(total_tokens=120, prompt_tokens=40, completion_tokens=80) + ) + assert extract_usage_from_chunk(chunk_direct) == { + "total_tokens": 120, + "prompt_tokens": 40, + "completion_tokens": 80, + } + + chunk_hidden = SimpleNamespace( + _hidden_params={"usage": {"total_tokens": 30, "prompt_tokens": 10, "completion_tokens": 20}} + ) + assert extract_usage_from_chunk(chunk_hidden) == { + "total_tokens": 30, + "prompt_tokens": 10, + "completion_tokens": 20, + } + + class ChunkWithModelDump: # pylint: disable=too-few-public-methods + def model_dump(self) -> Dict[str, Any]: + return {"usage": {"total_tokens": 12, "prompt_tokens": 5, "completion_tokens": 7}} + + assert extract_usage_from_chunk(ChunkWithModelDump()) == { + "total_tokens": 12, + "prompt_tokens": 5, + "completion_tokens": 7, + } + + def test_calculate_streaming_usage_and_cost_with_actual_usage(self) -> None: + """Actual usage data should be returned when available.""" + from openlayer.lib.integrations.portkey_tracer import calculate_streaming_usage_and_cost + + latest_usage = {"total_tokens": 100, "prompt_tokens": 40, "completion_tokens": 60} + latest_metadata = {"cost": 0.99} + + result = calculate_streaming_usage_and_cost( + chunks=[], + messages=[], + output_content="", + model_name="gpt-4", + latest_usage_data=latest_usage, + latest_chunk_metadata=latest_metadata, + ) + + assert result == (60, 40, 
100, 0.99) + + def test_calculate_streaming_usage_and_cost_fallback_estimation(self) -> None: + """Fallback estimation should approximate tokens and cost when usage is missing.""" + from openlayer.lib.integrations.portkey_tracer import calculate_streaming_usage_and_cost + + output_content = "Generated answer text." + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Tell me something interesting."}, + ] + + completion_tokens, prompt_tokens, total_tokens, cost = calculate_streaming_usage_and_cost( + chunks=[{"usage": None}], + messages=messages, + output_content=output_content, + model_name="gpt-3.5-turbo", + latest_usage_data={"total_tokens": None, "prompt_tokens": None, "completion_tokens": None}, + latest_chunk_metadata={}, + ) + + assert completion_tokens >= 1 + assert prompt_tokens >= 1 + assert total_tokens == (completion_tokens or 0) + (prompt_tokens or 0) + assert cost is not None + assert cost >= 0 + + def test_detect_provider_from_response_prefers_headers(self) -> None: + """Provider detection should prioritize Portkey headers.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider_from_response + + client = SimpleNamespace() + response = SimpleNamespace() + + with patch( + "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider" + ): + assert detect_provider_from_response(response, client, "gpt-4") == "header-provider" + + def test_detect_provider_from_chunk_prefers_headers(self) -> None: + """Provider detection from chunk should prioritize header-derived values.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider_from_chunk + + client = SimpleNamespace() + chunk = SimpleNamespace() + + with patch( + "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider" + ): + assert detect_provider_from_chunk(chunk, client, "gpt-4") == "header-provider" + + def test_detect_provider_from_response_fallback(self) -> None: + """Provider detection should fall back to response metadata or model name.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider_from_response + + client = SimpleNamespace(headers={"x-portkey-provider": "openai"}) + response = SimpleNamespace( + _response_headers={"X-Portkey-Provider": "anthropic"}, + response_metadata={"provider": "anthropic"}, + ) + + with patch( + "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None + ): + assert detect_provider_from_response(response, client, "mistral-7b") == "anthropic" + + def test_detect_provider_from_chunk_fallback(self) -> None: + """Chunk provider detection should fall back gracefully.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider_from_chunk + + chunk = SimpleNamespace( + response_metadata={"provider": "cohere"}, + _response_headers={"X-Portkey-Provider": "cohere"}, + ) + client = SimpleNamespace() + + with patch( + "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None + ): + assert detect_provider_from_chunk(chunk, client, "command-r") == "cohere" + + def test_provider_from_portkey_headers(self) -> None: + """Header helper should identify provider values on the client.""" + from openlayer.lib.integrations.portkey_tracer import _provider_from_portkey_headers + + client = SimpleNamespace( + default_headers={"X-Portkey-Provider": "openai"}, + headers={"X-Portkey-Provider": "anthropic"}, + ) + + assert 
_provider_from_portkey_headers(client) == "openai" + + def test_parse_non_streaming_output_data(self) -> None: + """Output parsing should support content, function calls, and tool calls.""" + from openlayer.lib.integrations.portkey_tracer import parse_non_streaming_output_data + + # Content message + response_content = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", function_call=None, tool_calls=None))] + ) + assert parse_non_streaming_output_data(response_content) == "Hello!" + + # Function call + response_function = SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace( + content=None, + function_call=SimpleNamespace(name="do_something", arguments=json.dumps({"value": 1})), + tool_calls=None, + ) + ) + ] + ) + assert parse_non_streaming_output_data(response_function) == {"name": "do_something", "arguments": {"value": 1}} + + # Tool call + response_tool = SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace( + content=None, + function_call=None, + tool_calls=[ + SimpleNamespace( + function=SimpleNamespace(name="call_tool", arguments=json.dumps({"value": 2})) + ) + ], + ) + ) + ] + ) + assert parse_non_streaming_output_data(response_tool) == {"name": "call_tool", "arguments": {"value": 2}} + + def test_create_trace_args(self) -> None: + """Trace argument helper should include optional id and cost.""" + from openlayer.lib.integrations.portkey_tracer import create_trace_args + + args = create_trace_args( + end_time=1.0, + inputs={"prompt": []}, + output="response", + latency=123.0, + tokens=10, + prompt_tokens=4, + completion_tokens=6, + model="gpt-4", + id="trace-id", + cost=0.42, + ) + + assert args["id"] == "trace-id" + assert args["cost"] == 0.42 + assert args["metadata"] == {} + + def test_add_to_trace_uses_provider_metadata(self) -> None: + """add_to_trace should pass provider metadata through to tracer.""" + from openlayer.lib.integrations.portkey_tracer import add_to_trace + + with patch( + "openlayer.lib.integrations.portkey_tracer.tracer.add_chat_completion_step_to_trace" + ) as mock_add: + add_to_trace( + end_time=1.0, + inputs={}, + output=None, + latency=10.0, + tokens=None, + prompt_tokens=None, + completion_tokens=None, + model="model", + metadata={}, + ) + + _, kwargs = mock_add.call_args + assert kwargs["provider"] == "Portkey" + assert kwargs["name"] == "Portkey Chat Completion" + + add_to_trace( + end_time=2.0, + inputs={}, + output=None, + latency=5.0, + tokens=None, + prompt_tokens=None, + completion_tokens=None, + model="model", + metadata={"provider": "OpenAI"}, + ) + + assert mock_add.call_count == 2 + assert mock_add.call_args.kwargs["provider"] == "OpenAI" + + def test_handle_streaming_create_delegates_to_stream_chunks(self) -> None: + """handle_streaming_create should call the original create and stream_chunks.""" + from openlayer.lib.integrations.portkey_tracer import handle_streaming_create + + client = SimpleNamespace() + create_func = Mock(return_value=iter(["chunk"])) + + with patch( + "openlayer.lib.integrations.portkey_tracer.stream_chunks", return_value=iter(["chunk"]) + ) as mock_stream_chunks: + result_iterator = handle_streaming_create( + client, + "arg-1", + create_func=create_func, + inference_id="stream-id", + foo="bar", + ) + + assert list(result_iterator) == ["chunk"] + create_func.assert_called_once_with("arg-1", foo="bar") + mock_stream_chunks.assert_called_once() + stream_kwargs = mock_stream_chunks.call_args.kwargs + assert stream_kwargs["client"] is client + assert 
stream_kwargs["inference_id"] == "stream-id" + assert stream_kwargs["kwargs"] == {"foo": "bar"} + assert stream_kwargs["chunks"] is create_func.return_value + + def test_stream_chunks_traces_completion(self) -> None: + """stream_chunks should yield all chunks and record a traced step.""" + from openlayer.lib.integrations.portkey_tracer import stream_chunks + + chunk_a = object() + chunk_b = object() + chunks = [chunk_a, chunk_b] + kwargs = {"messages": [{"role": "user", "content": "hello"}]} + client = SimpleNamespace() + + with patch( + "openlayer.lib.integrations.portkey_tracer.add_to_trace", autospec=True + ) as mock_add_to_trace, patch( + "openlayer.lib.integrations.portkey_tracer.extract_usage_from_chunk", autospec=True + ) as mock_usage, patch( + "openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata", autospec=True + ) as mock_unit_metadata, patch( + "openlayer.lib.integrations.portkey_tracer.detect_provider_from_chunk", autospec=True + ) as mock_detect_provider, patch( + "openlayer.lib.integrations.portkey_tracer.get_delta_from_chunk", autospec=True + ) as mock_delta, patch( + "openlayer.lib.integrations.portkey_tracer.calculate_streaming_usage_and_cost", autospec=True + ) as mock_calc, patch( + "openlayer.lib.integrations.portkey_tracer.extract_portkey_metadata", autospec=True + ) as mock_client_metadata, patch( + "openlayer.lib.integrations.portkey_tracer.time.time", side_effect=[100.0, 100.05, 100.2] + ): + mock_usage.side_effect = [ + {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None}, + {"total_tokens": 10, "prompt_tokens": 4, "completion_tokens": 6}, + ] + mock_unit_metadata.side_effect = [{}, {"cost": 0.1}] + mock_detect_provider.side_effect = ["OpenAI", "OpenAI"] + mock_delta.side_effect = [ + SimpleNamespace(content="Hello ", function_call=None, tool_calls=None), + SimpleNamespace(content="World", function_call=None, tool_calls=None), + ] + mock_calc.return_value = (6, 4, 10, 0.1) + mock_client_metadata.return_value = {"portkeyBaseUrl": "https://gateway"} + + yielded = list( + stream_chunks( + chunks=iter(chunks), + kwargs=kwargs, + client=client, + inference_id="trace-123", + ) + ) + + assert yielded == chunks + mock_add_to_trace.assert_called_once() + trace_kwargs = mock_add_to_trace.call_args.kwargs + assert trace_kwargs["metadata"]["provider"] == "OpenAI" + assert trace_kwargs["metadata"]["portkeyBaseUrl"] == "https://gateway" + assert trace_kwargs["id"] == "trace-123" + assert trace_kwargs["tokens"] == 10 + assert trace_kwargs["latency"] == pytest.approx(200.0) + + def test_handle_non_streaming_create_traces_completion(self) -> None: + """handle_non_streaming_create should record a traced step for completions.""" + from openlayer.lib.integrations.portkey_tracer import handle_non_streaming_create + + response = SimpleNamespace(model="gpt-4", system_fingerprint="fp-1") + client = SimpleNamespace() + create_func = Mock(return_value=response) + + with patch( + "openlayer.lib.integrations.portkey_tracer.parse_non_streaming_output_data", return_value="output" + ), patch( + "openlayer.lib.integrations.portkey_tracer.extract_usage_from_response", + return_value={"total_tokens": 10, "prompt_tokens": 4, "completion_tokens": 6}, + ), patch( + "openlayer.lib.integrations.portkey_tracer.detect_provider_from_response", return_value="OpenAI" + ), patch( + "openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata", + return_value={"cost": 0.25}, + ), patch( + "openlayer.lib.integrations.portkey_tracer.extract_portkey_metadata", + 
return_value={"portkeyHeaders": {"X-Portkey-Provider": "openai"}}, + ), patch( + "openlayer.lib.integrations.portkey_tracer.add_to_trace" + ) as mock_add_to_trace, patch( + "openlayer.lib.integrations.portkey_tracer.time.time", side_effect=[10.0, 10.2] + ): + result = handle_non_streaming_create( + client, + create_func=create_func, + inference_id="trace-xyz", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert result is response + mock_add_to_trace.assert_called_once() + trace_kwargs = mock_add_to_trace.call_args.kwargs + assert trace_kwargs["id"] == "trace-xyz" + assert trace_kwargs["metadata"]["provider"] == "OpenAI" + assert trace_kwargs["metadata"]["cost"] == 0.25 + assert trace_kwargs["metadata"]["portkeyHeaders"]["X-Portkey-Provider"] == "openai" + + From 83ec6a174d982f68ffa3fa72adc3bb3344754e4c Mon Sep 17 00:00:00 2001 From: Priyank <15610225+priyankinfinnov@users.noreply.github.com> Date: Thu, 13 Nov 2025 09:52:52 +0530 Subject: [PATCH 3/7] doc string updted --- src/openlayer/lib/__init__.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/openlayer/lib/__init__.py b/src/openlayer/lib/__init__.py index 447afd7d..d776dc23 100644 --- a/src/openlayer/lib/__init__.py +++ b/src/openlayer/lib/__init__.py @@ -197,6 +197,25 @@ def trace_portkey(): This function patches Portkey's chat.completions.create to automatically trace all OpenAI-compatible completions routed via the Portkey AI Gateway. + + Example: + >>> from portkey_ai import Portkey + >>> from openlayer.lib import trace_portkey + >>> # Enable openlayer tracing for all Portkey completions + >>> trace_portkey() + >>> # Basic portkey client initialization + >>> portkey = Portkey( + >>> api_key = os.environ['PORTKEY_API_KEY'], + >>> config = "YOUR_PORTKEY_CONFIG_ID", # optional your portkey config id + >>> ) + >>> # use portkey normally - tracing happens automatically + >>> response = portkey.chat.completions.create( + >>> #model = "@YOUR_PORTKEY_SLUG/YOUR_MODEL_NAME", # optional if giving config + >>> messages = [ + >>> {"role": "system", "content": "You are a helpful assistant."}, + >>> {"role": "user", "content": "Write a poem on Argentina, least 100 words."} + >>> ] + >>> ) """ # pylint: disable=import-outside-toplevel try: From a4c1d5a5e3a0ae387f45528da22e8bba7a6ecd7b Mon Sep 17 00:00:00 2001 From: Priyank <15610225+priyankinfinnov@users.noreply.github.com> Date: Thu, 13 Nov 2025 09:58:57 +0530 Subject: [PATCH 4/7] updates docstrings --- .../lib/integrations/portkey_tracer.py | 49 +++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/openlayer/lib/integrations/portkey_tracer.py b/src/openlayer/lib/integrations/portkey_tracer.py index 4313ecef..3c960a53 100644 --- a/src/openlayer/lib/integrations/portkey_tracer.py +++ b/src/openlayer/lib/integrations/portkey_tracer.py @@ -46,6 +46,7 @@ def trace_portkey() -> None: >>> from portkey_ai import Portkey >>> from openlayer.lib import trace_portkey >>> + >>> # Enable tracing >>> trace_portkey() >>> >>> # Use Portkey normally - tracing happens automatically @@ -124,7 +125,27 @@ def handle_streaming_create( inference_id: Optional[str] = None, **kwargs, ) -> Iterator[Any]: - """Handles streaming chat.completions.create routed via Portkey.""" + """ + Handles streaming chat.completions.create routed via Portkey. + + Parameters + ---------- + client : Portkey + The Portkey client instance making the request. + *args : + Positional arguments passed to the create function. 
+ create_func : callable + The create function to call (typically chat.completions.create). + inference_id : Optional[str], default None + Optional inference ID for tracking this request. + **kwargs : + Additional keyword arguments forwarded to create_func. + + Returns + ------- + Iterator[Any] + A generator that yields the chunks of the completion. + """ # Portkey is OpenAI-compatible; request and chunks follow OpenAI spec # create_func is a bound method; do not pass client again chunks = create_func(*args, **kwargs) @@ -265,7 +286,27 @@ def handle_non_streaming_create( inference_id: Optional[str] = None, **kwargs, ) -> Any: - """Handles non-streaming chat.completions.create routed via Portkey.""" + """ + Handles non-streaming chat.completions.create routed via Portkey. + + Parameters + ---------- + client : Portkey + The Portkey client instance used for routing the request. + *args : + Positional arguments for the create function. + create_func : callable + The function used to create the chat completion. This is a bound method, so do not pass client again. + inference_id : Optional[str], optional + A unique identifier for the inference or trace, by default None. + **kwargs : + Additional keyword arguments passed to the create function (e.g., "messages", "model", etc.). + + Returns + ------- + Any + The completion response as returned by the create function. + """ start_time = time.time() # create_func is a bound method; do not pass client again response = create_func(*args, **kwargs) @@ -560,7 +601,7 @@ def calculate_streaming_usage_and_cost( latest_usage_data: Dict[str, Optional[int]], latest_chunk_metadata: Dict[str, Any], ): - """Calculate usage and cost at the end of streaming (LiteLLM-style logic).""" + """Calculate usage and cost at the end of streaming.""" try: # Priority 1: Actual usage provided in chunks if latest_usage_data and latest_usage_data.get("total_tokens") and latest_usage_data.get("total_tokens") > 0: @@ -679,7 +720,7 @@ def detect_provider_from_chunk(chunk: Any, client: "Portkey", model_name: str) - def detect_provider_from_model_name(model_name: str) -> str: - """Detect provider from model name patterns (similar to LiteLLM tracer).""" + """Detect provider from model name patterns.""" model_lower = (model_name or "").lower() if model_lower.startswith(("gpt-", "o1-", "text-davinci", "text-curie", "text-babbage", "text-ada")): return "OpenAI" From 23618dabdcbf35c57d16d37836183884c3ca267e Mon Sep 17 00:00:00 2001 From: Priyank <15610225+priyankinfinnov@users.noreply.github.com> Date: Thu, 13 Nov 2025 10:22:23 +0530 Subject: [PATCH 5/7] refactoring to increase re-usability of code --- .../lib/integrations/portkey_tracer.py | 134 ++++++++++-------- tests/test_portkey_integration.py | 38 ++--- 2 files changed, 90 insertions(+), 82 deletions(-) diff --git a/src/openlayer/lib/integrations/portkey_tracer.py b/src/openlayer/lib/integrations/portkey_tracer.py index 3c960a53..4d36a7a2 100644 --- a/src/openlayer/lib/integrations/portkey_tracer.py +++ b/src/openlayer/lib/integrations/portkey_tracer.py @@ -185,12 +185,12 @@ def stream_chunks( if i == 0: first_token_time = time.time() # Try to detect provider at first chunk - provider = detect_provider_from_chunk(chunk, client, model_name) + provider = detect_provider(chunk, client, model_name) if i > 0: num_of_completion_tokens = i + 1 # Extract usage from chunk if available - chunk_usage = extract_usage_from_chunk(chunk) + chunk_usage = extract_usage(chunk) if any(v is not None for v in chunk_usage.values()): latest_usage_data = 
chunk_usage @@ -317,9 +317,9 @@ def handle_non_streaming_create( output_data = parse_non_streaming_output_data(response) # Usage (if provided by upstream provider via Portkey) - usage_data = extract_usage_from_response(response) + usage_data = extract_usage(response) model_name = getattr(response, "model", kwargs.get("model", "unknown")) - provider = detect_provider_from_response(response, client, model_name) + provider = detect_provider(response, client, model_name) extra_metadata = extract_portkey_unit_metadata(response, model_name) cost = extra_metadata.get("cost", None) @@ -530,36 +530,39 @@ def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]: return metadata -def extract_usage_from_response(response: Any) -> Dict[str, Optional[int]]: - """Extract usage from a non-streaming response.""" - try: - if hasattr(response, "usage") and response.usage is not None: - usage = response.usage - return { - "total_tokens": getattr(usage, "total_tokens", None), - "prompt_tokens": getattr(usage, "prompt_tokens", None), - "completion_tokens": getattr(usage, "completion_tokens", None), - } - except Exception: - pass - return {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None} - - -def extract_usage_from_chunk(chunk: Any) -> Dict[str, Optional[int]]: - """Extract usage from a streaming chunk if present.""" +def extract_usage(obj: Any) -> Dict[str, Optional[int]]: + """Extract usage from a response or chunk object. + + This function attempts to extract token usage information from various + locations where it might be stored, including: + - Direct `usage` attribute + - `_hidden_params` (for streaming chunks) + - `model_dump()` dictionary (for streaming chunks) + + Parameters + ---------- + obj : Any + The response or chunk object to extract usage from. + + Returns + ------- + Dict[str, Optional[int]] + Dictionary with keys: total_tokens, prompt_tokens, completion_tokens. + Values are None if usage information is not found. 
+ """ try: - # Check for usage attribute - if hasattr(chunk, "usage") and chunk.usage is not None: - usage = chunk.usage + # Check for direct usage attribute (works for both response and chunk) + if hasattr(obj, "usage") and obj.usage is not None: + usage = obj.usage return { "total_tokens": getattr(usage, "total_tokens", None), "prompt_tokens": getattr(usage, "prompt_tokens", None), "completion_tokens": getattr(usage, "completion_tokens", None), } - # Check for usage in _hidden_params (if SDK stores it there) - if hasattr(chunk, "_hidden_params"): - hidden_params = chunk._hidden_params + # Check for usage in _hidden_params (primarily for streaming chunks) + if hasattr(obj, "_hidden_params"): + hidden_params = obj._hidden_params # Check if usage is a direct attribute if hasattr(hidden_params, "usage") and hidden_params.usage is not None: usage = hidden_params.usage @@ -578,11 +581,11 @@ def extract_usage_from_chunk(chunk: Any) -> Dict[str, Optional[int]]: "completion_tokens": usage.get("completion_tokens", None), } - # Check if chunk model dump has usage - if hasattr(chunk, "model_dump"): - chunk_dict = chunk.model_dump() - if _supports_membership_check(chunk_dict) and "usage" in chunk_dict and chunk_dict["usage"]: - usage = chunk_dict["usage"] + # Check if object model dump has usage (primarily for streaming chunks) + if hasattr(obj, "model_dump"): + obj_dict = obj.model_dump() + if _supports_membership_check(obj_dict) and "usage" in obj_dict and obj_dict["usage"]: + usage = obj_dict["usage"] return { "total_tokens": usage.get("total_tokens", None), "prompt_tokens": usage.get("prompt_tokens", None), @@ -672,49 +675,54 @@ def calculate_streaming_usage_and_cost( return None, None, None, None -def detect_provider_from_response(response: Any, client: "Portkey", model_name: str) -> str: - """Detect provider for non-streaming responses.""" - # First: check Portkey headers on the client (authoritative) - provider = _provider_from_portkey_headers(client) - if provider: - return provider - # Next: check response metadata if any +def _extract_provider_from_object(obj: Any) -> Optional[str]: + """Extract provider from a response or chunk object. + + Checks response_metadata and _response_headers for provider information. + Returns None if no provider is found. + """ try: - # Some SDKs attach response headers/metadata - if hasattr(response, "response_metadata") and _is_dict_like(response.response_metadata): - if "provider" in response.response_metadata: - return response.response_metadata["provider"] - if hasattr(response, "_response_headers"): - headers = getattr(response, "_response_headers") + # Check response_metadata + if hasattr(obj, "response_metadata") and _is_dict_like(obj.response_metadata): + if "provider" in obj.response_metadata: + return obj.response_metadata["provider"] + # Check _response_headers + if hasattr(obj, "_response_headers"): + headers = getattr(obj, "_response_headers") if _is_dict_like(headers): for k, v in headers.items(): if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: return str(v) except Exception: pass - # Fallback to model name heuristics - return detect_provider_from_model_name(model_name) + return None -def detect_provider_from_chunk(chunk: Any, client: "Portkey", model_name: str) -> str: - """Detect provider for streaming chunks.""" - # First: check Portkey headers on the client +def detect_provider(obj: Any, client: "Portkey", model_name: str) -> str: + """Detect provider from a response or chunk object. 
+ + Parameters + ---------- + obj : Any + The response or chunk object to extract provider information from. + client : Portkey + The Portkey client instance. + model_name : str + The model name to use as a fallback for provider detection. + + Returns + ------- + str + The detected provider name. + """ + # First: check Portkey headers on the client (authoritative) provider = _provider_from_portkey_headers(client) if provider: return provider - # Next: see if chunk exposes any metadata - try: - if hasattr(chunk, "response_metadata") and _is_dict_like(chunk.response_metadata): - if "provider" in chunk.response_metadata: - return chunk.response_metadata["provider"] - if hasattr(chunk, "_response_headers"): - headers = getattr(chunk, "_response_headers") - if _is_dict_like(headers): - for k, v in headers.items(): - if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: - return str(v) - except Exception: - pass + # Next: check object metadata if any + provider = _extract_provider_from_object(obj) + if provider: + return provider # Fallback to model name heuristics return detect_provider_from_model_name(model_name) diff --git a/tests/test_portkey_integration.py b/tests/test_portkey_integration.py index 410c6b39..c9d2de5e 100644 --- a/tests/test_portkey_integration.py +++ b/tests/test_portkey_integration.py @@ -197,19 +197,19 @@ def items(self): def test_extract_usage_from_response(self) -> None: """Usage extraction should read OpenAI-style usage objects.""" - from openlayer.lib.integrations.portkey_tracer import extract_usage_from_response + from openlayer.lib.integrations.portkey_tracer import extract_usage usage = SimpleNamespace(total_tokens=50, prompt_tokens=20, completion_tokens=30) response = SimpleNamespace(usage=usage) - assert extract_usage_from_response(response) == { + assert extract_usage(response) == { "total_tokens": 50, "prompt_tokens": 20, "completion_tokens": 30, } response_no_usage = SimpleNamespace() - assert extract_usage_from_response(response_no_usage) == { + assert extract_usage(response_no_usage) == { "total_tokens": None, "prompt_tokens": None, "completion_tokens": None, @@ -217,12 +217,12 @@ def test_extract_usage_from_response(self) -> None: def test_extract_usage_from_chunk(self) -> None: """Usage data should be derived from multiple potential chunk attributes.""" - from openlayer.lib.integrations.portkey_tracer import extract_usage_from_chunk + from openlayer.lib.integrations.portkey_tracer import extract_usage chunk_direct = SimpleNamespace( usage=SimpleNamespace(total_tokens=120, prompt_tokens=40, completion_tokens=80) ) - assert extract_usage_from_chunk(chunk_direct) == { + assert extract_usage(chunk_direct) == { "total_tokens": 120, "prompt_tokens": 40, "completion_tokens": 80, @@ -231,7 +231,7 @@ def test_extract_usage_from_chunk(self) -> None: chunk_hidden = SimpleNamespace( _hidden_params={"usage": {"total_tokens": 30, "prompt_tokens": 10, "completion_tokens": 20}} ) - assert extract_usage_from_chunk(chunk_hidden) == { + assert extract_usage(chunk_hidden) == { "total_tokens": 30, "prompt_tokens": 10, "completion_tokens": 20, @@ -241,7 +241,7 @@ class ChunkWithModelDump: # pylint: disable=too-few-public-methods def model_dump(self) -> Dict[str, Any]: return {"usage": {"total_tokens": 12, "prompt_tokens": 5, "completion_tokens": 7}} - assert extract_usage_from_chunk(ChunkWithModelDump()) == { + assert extract_usage(ChunkWithModelDump()) == { "total_tokens": 12, "prompt_tokens": 5, "completion_tokens": 7, @@ -292,7 +292,7 @@ def 
test_calculate_streaming_usage_and_cost_fallback_estimation(self) -> None: def test_detect_provider_from_response_prefers_headers(self) -> None: """Provider detection should prioritize Portkey headers.""" - from openlayer.lib.integrations.portkey_tracer import detect_provider_from_response + from openlayer.lib.integrations.portkey_tracer import detect_provider client = SimpleNamespace() response = SimpleNamespace() @@ -300,11 +300,11 @@ def test_detect_provider_from_response_prefers_headers(self) -> None: with patch( "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider" ): - assert detect_provider_from_response(response, client, "gpt-4") == "header-provider" + assert detect_provider(response, client, "gpt-4") == "header-provider" def test_detect_provider_from_chunk_prefers_headers(self) -> None: """Provider detection from chunk should prioritize header-derived values.""" - from openlayer.lib.integrations.portkey_tracer import detect_provider_from_chunk + from openlayer.lib.integrations.portkey_tracer import detect_provider client = SimpleNamespace() chunk = SimpleNamespace() @@ -312,11 +312,11 @@ def test_detect_provider_from_chunk_prefers_headers(self) -> None: with patch( "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider" ): - assert detect_provider_from_chunk(chunk, client, "gpt-4") == "header-provider" + assert detect_provider(chunk, client, "gpt-4") == "header-provider" def test_detect_provider_from_response_fallback(self) -> None: """Provider detection should fall back to response metadata or model name.""" - from openlayer.lib.integrations.portkey_tracer import detect_provider_from_response + from openlayer.lib.integrations.portkey_tracer import detect_provider client = SimpleNamespace(headers={"x-portkey-provider": "openai"}) response = SimpleNamespace( @@ -327,11 +327,11 @@ def test_detect_provider_from_response_fallback(self) -> None: with patch( "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None ): - assert detect_provider_from_response(response, client, "mistral-7b") == "anthropic" + assert detect_provider(response, client, "mistral-7b") == "anthropic" def test_detect_provider_from_chunk_fallback(self) -> None: """Chunk provider detection should fall back gracefully.""" - from openlayer.lib.integrations.portkey_tracer import detect_provider_from_chunk + from openlayer.lib.integrations.portkey_tracer import detect_provider chunk = SimpleNamespace( response_metadata={"provider": "cohere"}, @@ -342,7 +342,7 @@ def test_detect_provider_from_chunk_fallback(self) -> None: with patch( "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None ): - assert detect_provider_from_chunk(chunk, client, "command-r") == "cohere" + assert detect_provider(chunk, client, "command-r") == "cohere" def test_provider_from_portkey_headers(self) -> None: """Header helper should identify provider values on the client.""" @@ -496,11 +496,11 @@ def test_stream_chunks_traces_completion(self) -> None: with patch( "openlayer.lib.integrations.portkey_tracer.add_to_trace", autospec=True ) as mock_add_to_trace, patch( - "openlayer.lib.integrations.portkey_tracer.extract_usage_from_chunk", autospec=True + "openlayer.lib.integrations.portkey_tracer.extract_usage", autospec=True ) as mock_usage, patch( "openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata", autospec=True ) as mock_unit_metadata, patch( - 
"openlayer.lib.integrations.portkey_tracer.detect_provider_from_chunk", autospec=True + "openlayer.lib.integrations.portkey_tracer.detect_provider", autospec=True ) as mock_detect_provider, patch( "openlayer.lib.integrations.portkey_tracer.get_delta_from_chunk", autospec=True ) as mock_delta, patch( @@ -552,10 +552,10 @@ def test_handle_non_streaming_create_traces_completion(self) -> None: with patch( "openlayer.lib.integrations.portkey_tracer.parse_non_streaming_output_data", return_value="output" ), patch( - "openlayer.lib.integrations.portkey_tracer.extract_usage_from_response", + "openlayer.lib.integrations.portkey_tracer.extract_usage", return_value={"total_tokens": 10, "prompt_tokens": 4, "completion_tokens": 6}, ), patch( - "openlayer.lib.integrations.portkey_tracer.detect_provider_from_response", return_value="OpenAI" + "openlayer.lib.integrations.portkey_tracer.detect_provider", return_value="OpenAI" ), patch( "openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata", return_value={"cost": 0.25}, From edf953a21b5f546edde9df6921653af086872c0c Mon Sep 17 00:00:00 2001 From: Priyank <15610225+priyankinfinnov@users.noreply.github.com> Date: Thu, 13 Nov 2025 10:50:14 +0530 Subject: [PATCH 6/7] fixing headers --- .../lib/integrations/portkey_tracer.py | 27 ++++++------------- tests/test_portkey_integration.py | 20 +++++++------- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/src/openlayer/lib/integrations/portkey_tracer.py b/src/openlayer/lib/integrations/portkey_tracer.py index 4d36a7a2..a72dd770 100644 --- a/src/openlayer/lib/integrations/portkey_tracer.py +++ b/src/openlayer/lib/integrations/portkey_tracer.py @@ -194,7 +194,7 @@ def stream_chunks( if any(v is not None for v in chunk_usage.values()): latest_usage_data = chunk_usage - # Update metadata from latest chunk (headers/cost/etc.) + # Update metadata from latest chunk (headers/etc.) 
chunk_metadata = extract_portkey_unit_metadata(chunk, model_name) if chunk_metadata: latest_chunk_metadata.update(chunk_metadata) @@ -487,7 +487,7 @@ def extract_portkey_metadata(client: "Portkey") -> Dict[str, Any]: def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]: - """Extract metadata from a response or chunk unit (headers, cost, ids).""" + """Extract metadata from a response or chunk unit (headers, ids).""" metadata: Dict[str, Any] = {} try: # Extract system fingerprint if available (OpenAI-compatible) @@ -512,16 +512,12 @@ def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]: lower = {k.lower(): v for k, v in headers.items()} if "x-portkey-trace-id" in lower: metadata["portkey_trace_id"] = lower["x-portkey-trace-id"] - if "x-portkey-provider" in lower: - metadata["provider"] = lower["x-portkey-provider"] if "x-portkey-cache-status" in lower: metadata["portkey_cache_status"] = lower["x-portkey-cache-status"] - # Cost if surfaced by gateway - if "x-portkey-cost" in lower: - try: - metadata["cost"] = float(lower["x-portkey-cost"]) - except Exception: - pass + if "x-portkey-retry-attempt-count" in lower: + metadata["portkey_retry_attempt_count"] = lower["x-portkey-retry-attempt-count"] + if "x-portkey-last-used-option-index" in lower: + metadata["portkey_last_used_option_index"] = lower["x-portkey-last-used-option-index"] except Exception: pass # Attach model for convenience @@ -662,7 +658,7 @@ def calculate_streaming_usage_and_cost( total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - # Cost from headers if present; otherwise simple heuristic for some models + # Cost from metadata if present; otherwise simple heuristic for some models cost = latest_chunk_metadata.get("cost") if cost is None and total_tokens and model_name: ml = model_name.lower() @@ -678,7 +674,7 @@ def calculate_streaming_usage_and_cost( def _extract_provider_from_object(obj: Any) -> Optional[str]: """Extract provider from a response or chunk object. - Checks response_metadata and _response_headers for provider information. + Checks response_metadata for provider information. Returns None if no provider is found. 
""" try: @@ -686,13 +682,6 @@ def _extract_provider_from_object(obj: Any) -> Optional[str]: if hasattr(obj, "response_metadata") and _is_dict_like(obj.response_metadata): if "provider" in obj.response_metadata: return obj.response_metadata["provider"] - # Check _response_headers - if hasattr(obj, "_response_headers"): - headers = getattr(obj, "_response_headers") - if _is_dict_like(headers): - for k, v in headers.items(): - if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: - return str(v) except Exception: pass return None diff --git a/tests/test_portkey_integration.py b/tests/test_portkey_integration.py index c9d2de5e..9b5e23c5 100644 --- a/tests/test_portkey_integration.py +++ b/tests/test_portkey_integration.py @@ -145,16 +145,16 @@ def test_extract_portkey_metadata(self) -> None: assert "Authorization" not in metadata["portkeyHeaders"] def test_extract_portkey_unit_metadata(self) -> None: - """Unit metadata should capture headers, cost, and provider hints.""" + """Unit metadata should capture headers and retry/option index hints.""" from openlayer.lib.integrations.portkey_tracer import extract_portkey_unit_metadata unit = SimpleNamespace( system_fingerprint="fingerprint-123", _response_headers={ "x-portkey-trace-id": "trace-1", - "x-portkey-provider": "anthropic", "x-portkey-cache-status": "HIT", - "x-portkey-cost": "0.45", + "x-portkey-retry-attempt-count": "2", + "x-portkey-last-used-option-index": "config.targets[1]", "content-type": "application/json", }, ) @@ -163,9 +163,9 @@ def test_extract_portkey_unit_metadata(self) -> None: assert metadata["system_fingerprint"] == "fingerprint-123" assert metadata["portkey_trace_id"] == "trace-1" - assert metadata["provider"] == "anthropic" assert metadata["portkey_cache_status"] == "HIT" - assert metadata["cost"] == pytest.approx(0.45) + assert metadata["portkey_retry_attempt_count"] == "2" + assert metadata["portkey_last_used_option_index"] == "config.targets[1]" assert metadata["portkey_model"] == "claude-3-opus" assert metadata["response_headers"]["content-type"] == "application/json" @@ -178,8 +178,9 @@ class DictLikeHeaders: def __init__(self): self._data = { "x-portkey-trace-id": "trace-2", - "x-portkey-provider": "openai", "x-portkey-cache-status": "MISS", + "x-portkey-retry-attempt-count": "3", + "x-portkey-last-used-option-index": "config.targets[1]", } def items(self): @@ -192,8 +193,9 @@ def items(self): metadata = extract_portkey_unit_metadata(unit, "gpt-4") assert metadata["portkey_trace_id"] == "trace-2" - assert metadata["provider"] == "openai" assert metadata["portkey_cache_status"] == "MISS" + assert metadata["portkey_retry_attempt_count"] == "3" + assert metadata["portkey_last_used_option_index"] == "config.targets[1]" def test_extract_usage_from_response(self) -> None: """Usage extraction should read OpenAI-style usage objects.""" @@ -318,9 +320,8 @@ def test_detect_provider_from_response_fallback(self) -> None: """Provider detection should fall back to response metadata or model name.""" from openlayer.lib.integrations.portkey_tracer import detect_provider - client = SimpleNamespace(headers={"x-portkey-provider": "openai"}) + client = SimpleNamespace() response = SimpleNamespace( - _response_headers={"X-Portkey-Provider": "anthropic"}, response_metadata={"provider": "anthropic"}, ) @@ -335,7 +336,6 @@ def test_detect_provider_from_chunk_fallback(self) -> None: chunk = SimpleNamespace( response_metadata={"provider": "cohere"}, - _response_headers={"X-Portkey-Provider": "cohere"}, ) client = 
SimpleNamespace() From 3c66d8b41b0b5262b93327d1e1e01d070324c9f0 Mon Sep 17 00:00:00 2001 From: Priyank <15610225+priyankinfinnov@users.noreply.github.com> Date: Thu, 13 Nov 2025 11:29:07 +0530 Subject: [PATCH 7/7] removed unused code --- .../lib/integrations/portkey_tracer.py | 22 ------------------- tests/test_portkey_integration.py | 9 -------- 2 files changed, 31 deletions(-) diff --git a/src/openlayer/lib/integrations/portkey_tracer.py b/src/openlayer/lib/integrations/portkey_tracer.py index a72dd770..f9c6ca9e 100644 --- a/src/openlayer/lib/integrations/portkey_tracer.py +++ b/src/openlayer/lib/integrations/portkey_tracer.py @@ -532,7 +532,6 @@ def extract_usage(obj: Any) -> Dict[str, Optional[int]]: This function attempts to extract token usage information from various locations where it might be stored, including: - Direct `usage` attribute - - `_hidden_params` (for streaming chunks) - `model_dump()` dictionary (for streaming chunks) Parameters @@ -556,27 +555,6 @@ def extract_usage(obj: Any) -> Dict[str, Optional[int]]: "completion_tokens": getattr(usage, "completion_tokens", None), } - # Check for usage in _hidden_params (primarily for streaming chunks) - if hasattr(obj, "_hidden_params"): - hidden_params = obj._hidden_params - # Check if usage is a direct attribute - if hasattr(hidden_params, "usage") and hidden_params.usage is not None: - usage = hidden_params.usage - return { - "total_tokens": getattr(usage, "total_tokens", None), - "prompt_tokens": getattr(usage, "prompt_tokens", None), - "completion_tokens": getattr(usage, "completion_tokens", None), - } - # Check if usage is a dictionary key - elif _supports_membership_check(hidden_params) and "usage" in hidden_params: - usage = hidden_params["usage"] - if usage: - return { - "total_tokens": usage.get("total_tokens", None), - "prompt_tokens": usage.get("prompt_tokens", None), - "completion_tokens": usage.get("completion_tokens", None), - } - # Check if object model dump has usage (primarily for streaming chunks) if hasattr(obj, "model_dump"): obj_dict = obj.model_dump() diff --git a/tests/test_portkey_integration.py b/tests/test_portkey_integration.py index 9b5e23c5..97e2d50e 100644 --- a/tests/test_portkey_integration.py +++ b/tests/test_portkey_integration.py @@ -230,15 +230,6 @@ def test_extract_usage_from_chunk(self) -> None: "completion_tokens": 80, } - chunk_hidden = SimpleNamespace( - _hidden_params={"usage": {"total_tokens": 30, "prompt_tokens": 10, "completion_tokens": 20}} - ) - assert extract_usage(chunk_hidden) == { - "total_tokens": 30, - "prompt_tokens": 10, - "completion_tokens": 20, - } - class ChunkWithModelDump: # pylint: disable=too-few-public-methods def model_dump(self) -> Dict[str, Any]: return {"usage": {"total_tokens": 12, "prompt_tokens": 5, "completion_tokens": 7}}
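
A minimal streaming sketch of the integration added in this series, assuming the notebook's PORTKEY_API_KEY and OPENLAYER_* environment variables are set and that the config id placeholder is replaced with a real Portkey config; streaming calls go through the same patched chat.completions.create, so they are traced like the non-streaming notebook example.

import os

from portkey_ai import Portkey

from openlayer.lib import trace_portkey

# Illustrative sketch (placeholder config id, env vars assumed to be set).
# Enable Openlayer tracing for all Portkey completions; the notebook calls
# this before creating the client.
trace_portkey()

portkey = Portkey(
    api_key=os.environ["PORTKEY_API_KEY"],
    config="YOUR_PORTKEY_CONFIG_ID_HERE",
)

# stream=True exercises the streaming path: each chunk is yielded back to the
# caller, and a single traced step (tokens, latency, provider, cost) is
# recorded once the stream is exhausted.
stream = portkey.chat.completions.create(
    messages=[{"role": "user", "content": "Say hello in three languages."}],
    stream=True,
)
for chunk in stream:
    if chunk.choices and getattr(chunk.choices[0].delta, "content", None):
        print(chunk.choices[0].delta.content, end="", flush=True)

Non-streaming calls follow the notebook example and are handled by handle_non_streaming_create.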