diff --git a/examples/tracing/portkey/portkey_tracing.ipynb b/examples/tracing/portkey/portkey_tracing.ipynb
new file mode 100644
index 00000000..4231b462
--- /dev/null
+++ b/examples/tracing/portkey/portkey_tracing.ipynb
@@ -0,0 +1,152 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openlayer-ai/openlayer-python/blob/main/examples/tracing/portkey/portkey_tracing.ipynb)\n",
+    "\n",
+    "\n",
+    "# Portkey monitoring quickstart\n",
+    "\n",
+    "This notebook illustrates how to get started monitoring Portkey completions with Openlayer.\n",
+    "\n",
+    "Portkey provides a unified interface to call 100+ LLM APIs using the same input/output format. This integration allows you to trace and monitor completions across all supported providers through a single interface.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install openlayer portkey-ai"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Set the environment variables\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "from portkey_ai import Portkey\n",
+    "\n",
+    "\n",
+    "# Set your Portkey API key\n",
+    "os.environ['PORTKEY_API_KEY'] = \"YOUR_PORTKEY_API_KEY_HERE\"\n",
+    "\n",
+    "# Openlayer env variables\n",
+    "os.environ[\"OPENLAYER_API_KEY\"] = \"YOUR_OPENLAYER_API_KEY_HERE\"\n",
+    "os.environ[\"OPENLAYER_INFERENCE_PIPELINE_ID\"] = \"YOUR_OPENLAYER_INFERENCE_PIPELINE_ID_HERE\"\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Enable Portkey tracing\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from openlayer.lib import trace_portkey\n",
+    "\n",
+    "# Enable Openlayer tracing for all Portkey completions\n",
+    "trace_portkey()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. Use Portkey normally - tracing happens automatically!\n",
+    "\n",
+    "### Basic completion\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Basic Portkey client initialization\n",
+    "portkey = Portkey(\n",
+    "    api_key = os.environ['PORTKEY_API_KEY'],\n",
+    "    config = \"YOUR_PORTKEY_CONFIG_ID_HERE\", # optional: your Portkey config ID\n",
+    ")\n",
+    "\n",
+    "# Basic Portkey LLM call\n",
+    "response = portkey.chat.completions.create(\n",
+    "    # model = \"@YOUR_PORTKEY_SLUG/YOUR_MODEL_NAME\", # optional if a config is provided\n",
+    "    messages = [\n",
+    "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
+    "        {\"role\": \"user\", \"content\": \"Write a poem on Argentina, at least 500 words.\"}\n",
+    "    ]\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. View your traces\n",
+    "\n",
+    "Once you've run the examples above, you can:\n",
+    "\n",
+    "1. **Visit your Openlayer dashboard** to see all the traced completions\n",
+    "2. **Analyze performance** across different models and providers\n",
+    "3. **Monitor costs** and token usage\n",
+    "4. **Debug issues** with detailed request/response logs\n",
+    "5. **Compare models** side-by-side\n",
+    "\n",
+    "The traces will include:\n",
+    "- **Request details**: Model, parameters, messages\n",
+    "- **Response data**: Generated content, token counts, latency\n",
+    "- **Provider information**: Which underlying service was used\n",
+    "- **Custom metadata**: Any additional context you provide\n",
+    "\n",
+    "For more information, check out:\n",
+    "- [Openlayer Documentation](https://docs.openlayer.com/)\n",
+    "- [Portkey Documentation](https://portkey.ai/docs)\n",
+    "- [Portkey AI Gateway](https://portkey.ai/docs/product/ai-gateway)\n",
+    "- [Portkey Supported Providers](https://portkey.ai/docs/api-reference/inference-api/supported-providers)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/openlayer/lib/__init__.py b/src/openlayer/lib/__init__.py
index 0be95646..53cb9c24 100644
--- a/src/openlayer/lib/__init__.py
+++ b/src/openlayer/lib/__init__.py
@@ -14,6 +14,7 @@
     "trace_oci_genai",
     "trace_oci",  # Alias for backward compatibility
     "trace_litellm",
+    "trace_portkey",
     "trace_google_adk",
     "unpatch_google_adk",
     "update_current_trace",
@@ -192,6 +193,41 @@ def trace_litellm():
     return litellm_tracer.trace_litellm()
 
 
+# ---------------------------------- Portkey ---------------------------------- #
+def trace_portkey():
+    """Enable tracing for Portkey completions.
+
+    This function patches Portkey's chat.completions.create to automatically trace
+    all OpenAI-compatible completions routed via the Portkey AI Gateway.
+
+    Example:
+    >>> from portkey_ai import Portkey
+    >>> from openlayer.lib import trace_portkey
+    >>> # Enable Openlayer tracing for all Portkey completions
+    >>> trace_portkey()
+    >>> # Basic Portkey client initialization
+    >>> portkey = Portkey(
+    ...     api_key = os.environ['PORTKEY_API_KEY'],
+    ...     config = "YOUR_PORTKEY_CONFIG_ID",  # optional: your Portkey config ID
+    ... )
+    >>> # Use Portkey normally - tracing happens automatically
+    >>> response = portkey.chat.completions.create(
+    ...     # model = "@YOUR_PORTKEY_SLUG/YOUR_MODEL_NAME",  # optional if a config is provided
+    ...     messages = [
+    ...         {"role": "system", "content": "You are a helpful assistant."},
+    ...         {"role": "user", "content": "Write a poem on Argentina, at least 100 words."}
+    ...     ]
+    ... )
+    """
+    # pylint: disable=import-outside-toplevel
+    try:
+        from portkey_ai import Portkey  # noqa: F401
+    except ImportError:
+        raise ImportError("portkey-ai is required for Portkey tracing. Install with: pip install portkey-ai")
+
+    from .integrations import portkey_tracer
+
+    return portkey_tracer.trace_portkey()
 # ------------------------------ Google ADK ---------------------------------- #
 def trace_google_adk():
     """Enable tracing for Google Agent Development Kit (ADK).
diff --git a/src/openlayer/lib/integrations/portkey_tracer.py b/src/openlayer/lib/integrations/portkey_tracer.py
new file mode 100644
index 00000000..f9c6ca9e
--- /dev/null
+++ b/src/openlayer/lib/integrations/portkey_tracer.py
@@ -0,0 +1,756 @@
+"""Module with methods used to trace Portkey AI Gateway chat completions."""
+
+import json
+import logging
+import time
+from functools import wraps
+from typing import Any, Dict, Iterator, Optional, Union, TYPE_CHECKING
+
+try:
+    from portkey_ai import Portkey
+    HAVE_PORTKEY = True
+except ImportError:
+    HAVE_PORTKEY = False
+
+if TYPE_CHECKING:
+    from portkey_ai import Portkey
+
+from ..tracing import tracer
+
+logger = logging.getLogger(__name__)
+
+
+def trace_portkey() -> None:
+    """Patch Portkey's chat.completions.create to trace completions.
+
+    The following information is collected for each completion:
+    - start_time: The time when the completion was requested.
+    - end_time: The time when the completion was received.
+    - latency: The time it took to generate the completion.
+    - tokens: The total number of tokens used to generate the completion.
+    - prompt_tokens: The number of tokens in the prompt.
+    - completion_tokens: The number of tokens in the completion.
+    - model: The model used to generate the completion.
+    - model_parameters: The parameters used to configure the model.
+    - raw_output: The raw output of the model.
+    - inputs: The inputs used to generate the completion.
+    - Portkey-specific metadata (base URL, x-portkey-* headers if available)
+
+    Returns
+    -------
+    None
+        This function patches portkey.chat.completions.create in place.
+
+    Example
+    -------
+    >>> from portkey_ai import Portkey
+    >>> from openlayer.lib import trace_portkey
+    >>>
+    >>> # Enable tracing
+    >>> trace_portkey()
+    >>>
+    >>> # Use Portkey normally - tracing happens automatically
+    >>> portkey = Portkey(api_key = "YOUR_PORTKEY_API_KEY")
+    >>> response = portkey.chat.completions.create(
+    ...     model = "@YOUR_PROVIDER_SLUG/MODEL_NAME",
+    ...     messages = [
+    ...         {"role": "system", "content": "You are a helpful assistant."},
+    ...         {"role": "user", "content": "What is Portkey?"}
+    ...     ],
+    ...     inference_id="custom-id-123",  # Optional Openlayer parameter
+    ...     max_tokens = 512
+    ... )
+    """
+    if not HAVE_PORTKEY:
+        raise ImportError(
+            "Portkey library is not installed. Please install it with: pip install portkey-ai"
+        )
+
+    # Patch instances on initialization rather than class-level attributes.
+    # Some SDKs initialize 'chat' lazily on the instance.
+    original_init = Portkey.__init__
+
+    @wraps(original_init)
+    def traced_init(self, *args, **kwargs):
+        original_init(self, *args, **kwargs)
+        try:
+            # Avoid double-patching
+            if getattr(self, "_openlayer_portkey_patched", False):
+                return
+            # Access chat to ensure it's constructed, then wrap create
+            chat = getattr(self, "chat", None)
+            if chat is None or not hasattr(chat, "completions") or not hasattr(chat.completions, "create"):
+                # If the structure isn't present, skip gracefully and log diagnostics
+                logger.debug(
+                    "Openlayer Portkey tracer: Portkey client missing expected attributes (chat/completions/create). "
+                    "Tracing not applied for this instance."
+ ) + return + original_create = chat.completions.create + + @wraps(original_create) + def traced_create(*c_args, **c_kwargs): + inference_id = c_kwargs.pop("inference_id", None) + stream = c_kwargs.get("stream", False) + if stream: + return handle_streaming_create( + self, + *c_args, + create_func=original_create, + inference_id=inference_id, + **c_kwargs, + ) + return handle_non_streaming_create( + self, + *c_args, + create_func=original_create, + inference_id=inference_id, + **c_kwargs, + ) + + self.chat.completions.create = traced_create + setattr(self, "_openlayer_portkey_patched", True) + logger.debug("Openlayer Portkey tracer: successfully patched Portkey client instance for tracing.") + except Exception as e: + logger.debug("Failed to patch Portkey client instance for tracing: %s", e) + + Portkey.__init__ = traced_init + logger.info("Openlayer Portkey tracer: tracing enabled (instance-level patch).") + + +def handle_streaming_create( + client: "Portkey", + *args, + create_func: callable, + inference_id: Optional[str] = None, + **kwargs, +) -> Iterator[Any]: + """ + Handles streaming chat.completions.create routed via Portkey. + + Parameters + ---------- + client : Portkey + The Portkey client instance making the request. + *args : + Positional arguments passed to the create function. + create_func : callable + The create function to call (typically chat.completions.create). + inference_id : Optional[str], default None + Optional inference ID for tracking this request. + **kwargs : + Additional keyword arguments forwarded to create_func. + + Returns + ------- + Iterator[Any] + A generator that yields the chunks of the completion. + """ + # Portkey is OpenAI-compatible; request and chunks follow OpenAI spec + # create_func is a bound method; do not pass client again + chunks = create_func(*args, **kwargs) + return stream_chunks( + chunks=chunks, + kwargs=kwargs, + client=client, + inference_id=inference_id, + ) + + +def stream_chunks( + chunks: Iterator[Any], + kwargs: Dict[str, Any], + client: "Portkey", + inference_id: Optional[str] = None, +): + """Streams the chunks of the completion and traces the completion.""" + collected_output_data = [] + collected_function_call = {"name": "", "arguments": ""} + raw_outputs = [] + start_time = time.time() + end_time = None + first_token_time = None + num_of_completion_tokens = None + latency = None + model_name = kwargs.get("model", "unknown") + provider = "unknown" + latest_usage_data = {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None} + latest_chunk_metadata: Dict[str, Any] = {} + + try: + i = 0 + for i, chunk in enumerate(chunks): + raw_outputs.append(chunk.model_dump() if hasattr(chunk, "model_dump") else str(chunk)) + + if i == 0: + first_token_time = time.time() + # Try to detect provider at first chunk + provider = detect_provider(chunk, client, model_name) + if i > 0: + num_of_completion_tokens = i + 1 + + # Extract usage from chunk if available + chunk_usage = extract_usage(chunk) + if any(v is not None for v in chunk_usage.values()): + latest_usage_data = chunk_usage + + # Update metadata from latest chunk (headers/etc.) 
+ chunk_metadata = extract_portkey_unit_metadata(chunk, model_name) + if chunk_metadata: + latest_chunk_metadata.update(chunk_metadata) + + # Extract delta from chunk (OpenAI-compatible) + delta = get_delta_from_chunk(chunk) + + if delta and getattr(delta, "content", None): + collected_output_data.append(delta.content) + elif delta and getattr(delta, "function_call", None): + if delta.function_call.name: + collected_function_call["name"] += delta.function_call.name + if delta.function_call.arguments: + collected_function_call["arguments"] += delta.function_call.arguments + elif delta and getattr(delta, "tool_calls", None): + tool_call = delta.tool_calls[0] + if getattr(tool_call.function, "name", None): + collected_function_call["name"] += tool_call.function.name + if getattr(tool_call.function, "arguments", None): + collected_function_call["arguments"] += tool_call.function.arguments + + yield chunk + end_time = time.time() + latency = (end_time - start_time) * 1000 + # pylint: disable=broad-except + except Exception as e: + logger.error("Failed to yield Portkey chunk. %s", e) + finally: + # Try to add step to the trace + try: + collected_output_data = [m for m in collected_output_data if m is not None] + if collected_output_data: + output_data = "".join(collected_output_data) + else: + if collected_function_call["arguments"]: + try: + collected_function_call["arguments"] = json.loads(collected_function_call["arguments"]) + except json.JSONDecodeError: + pass + output_data = collected_function_call + + # Calculate usage and cost at end of stream (prioritize actual usage if present) + completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost( + chunks=raw_outputs, + messages=kwargs.get("messages", []), + output_content=output_data, + model_name=model_name, + latest_usage_data=latest_usage_data, + latest_chunk_metadata=latest_chunk_metadata, + ) + + usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {} + final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0) + final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens) + final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", (final_prompt_tokens or 0) + (final_completion_tokens or 0)) + final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get("cost", None) + + trace_args = create_trace_args( + end_time=end_time, + inputs={"prompt": kwargs.get("messages", [])}, + output=output_data, + latency=latency, + tokens=final_total_tokens, + prompt_tokens=final_prompt_tokens, + completion_tokens=final_completion_tokens, + model=model_name, + model_parameters=get_model_parameters(kwargs), + raw_output=raw_outputs, + id=inference_id, + cost=final_cost, + metadata={ + "timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None), + "provider": provider, + "portkey_model": model_name, + **extract_portkey_metadata(client), + **latest_chunk_metadata, + }, + ) + add_to_trace(**trace_args) + # pylint: disable=broad-except + except Exception as e: + logger.error("Failed to trace the Portkey streaming completion. 
%s", e) + + +def handle_non_streaming_create( + client: "Portkey", + *args, + create_func: callable, + inference_id: Optional[str] = None, + **kwargs, +) -> Any: + """ + Handles non-streaming chat.completions.create routed via Portkey. + + Parameters + ---------- + client : Portkey + The Portkey client instance used for routing the request. + *args : + Positional arguments for the create function. + create_func : callable + The function used to create the chat completion. This is a bound method, so do not pass client again. + inference_id : Optional[str], optional + A unique identifier for the inference or trace, by default None. + **kwargs : + Additional keyword arguments passed to the create function (e.g., "messages", "model", etc.). + + Returns + ------- + Any + The completion response as returned by the create function. + """ + start_time = time.time() + # create_func is a bound method; do not pass client again + response = create_func(*args, **kwargs) + end_time = time.time() + + # Try to add step to the trace + try: + output_data = parse_non_streaming_output_data(response) + + # Usage (if provided by upstream provider via Portkey) + usage_data = extract_usage(response) + model_name = getattr(response, "model", kwargs.get("model", "unknown")) + provider = detect_provider(response, client, model_name) + extra_metadata = extract_portkey_unit_metadata(response, model_name) + cost = extra_metadata.get("cost", None) + + trace_args = create_trace_args( + end_time=end_time, + inputs={"prompt": kwargs.get("messages", [])}, + output=output_data, + latency=(end_time - start_time) * 1000, + tokens=usage_data.get("total_tokens"), + prompt_tokens=usage_data.get("prompt_tokens"), + completion_tokens=usage_data.get("completion_tokens"), + model=model_name, + model_parameters=get_model_parameters(kwargs), + raw_output=response.model_dump() if hasattr(response, "model_dump") else str(response), + id=inference_id, + cost=cost, + metadata={ + "system_fingerprint": getattr(response, "system_fingerprint", None), + "provider": provider, + "portkey_model": model_name, + **extract_portkey_metadata(client), + **extra_metadata, + }, + ) + add_to_trace(**trace_args) + # pylint: disable=broad-except + except Exception as e: + logger.error("Failed to trace the Portkey non-streaming completion. 
%s", e) + + return response + + +def get_model_parameters(kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Gets the model parameters from the kwargs (OpenAI-compatible).""" + return { + "temperature": kwargs.get("temperature", 1), + "top_p": kwargs.get("top_p", 1), + "max_tokens": kwargs.get("max_tokens", None), + "n": kwargs.get("n", 1), + "stream": kwargs.get("stream", False), + "stop": kwargs.get("stop", None), + "presence_penalty": kwargs.get("presence_penalty", 0), + "frequency_penalty": kwargs.get("frequency_penalty", 0), + "logit_bias": kwargs.get("logit_bias", None), + "logprobs": kwargs.get("logprobs", False), + "top_logprobs": kwargs.get("top_logprobs", None), + "parallel_tool_calls": kwargs.get("parallel_tool_calls", True), + "seed": kwargs.get("seed", None), + "response_format": kwargs.get("response_format", None), + "timeout": kwargs.get("timeout", None), + "api_base": kwargs.get("api_base", None), + "api_version": kwargs.get("api_version", None), + } + + +def create_trace_args( + end_time: float, + inputs: Dict[str, Any], + output: Union[str, Dict[str, Any], None], + latency: float, + tokens: Optional[int], + prompt_tokens: Optional[int], + completion_tokens: Optional[int], + model: str, + model_parameters: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + raw_output: Optional[Union[str, Dict[str, Any]]] = None, + id: Optional[str] = None, + cost: Optional[float] = None, +) -> Dict[str, Any]: + """Returns a dictionary with the trace arguments.""" + trace_args = { + "end_time": end_time, + "inputs": inputs, + "output": output, + "latency": latency, + "tokens": tokens, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "model": model, + "model_parameters": model_parameters, + "raw_output": raw_output, + "metadata": metadata if metadata else {}, + } + if id: + trace_args["id"] = id + if cost is not None: + trace_args["cost"] = cost + return trace_args + + +def add_to_trace(**kwargs) -> None: + """Add a chat completion step to the trace.""" + provider = kwargs.get("metadata", {}).get("provider", "Portkey") + tracer.add_chat_completion_step_to_trace(**kwargs, name="Portkey Chat Completion", provider=provider) + + +def parse_non_streaming_output_data(response: Any) -> Union[str, Dict[str, Any], None]: + """Parses the output data from a non-streaming completion (OpenAI-compatible).""" + try: + if hasattr(response, "choices") and response.choices: + choice = response.choices[0] + message = getattr(choice, "message", None) + if message is None: + return None + content = getattr(message, "content", None) + function_call = getattr(message, "function_call", None) + tool_calls = getattr(message, "tool_calls", None) + if content: + return content.strip() + if function_call: + return { + "name": function_call.name, + "arguments": json.loads(function_call.arguments) if isinstance(function_call.arguments, str) else function_call.arguments, + } + if tool_calls: + return { + "name": tool_calls[0].function.name, + "arguments": json.loads(tool_calls[0].function.arguments) if isinstance(tool_calls[0].function.arguments, str) else tool_calls[0].function.arguments, + } + except Exception as e: + logger.debug("Error parsing Portkey output data: %s", e) + return None + + +def extract_portkey_metadata(client: "Portkey") -> Dict[str, Any]: + """Extract Portkey-specific metadata from a Portkey client. + + Attempts to read base URL and redact x-portkey-* headers if present. + Works defensively across SDK versions. 
+ """ + metadata: Dict[str, Any] = {"isPortkey": True} + # Base URL or host + for attr in ("base_url", "baseURL", "host", "custom_host"): + try: + val = getattr(client, attr, None) + if val: + metadata["portkeyBaseUrl"] = str(val) + break + except Exception: + continue + + # Headers + possible_header_attrs = ("default_headers", "headers", "_default_headers", "_headers", "custom_headers", "allHeaders") + redacted: Dict[str, Any] = {} + for attr in possible_header_attrs: + try: + headers = getattr(client, attr, None) + if _is_dict_like(headers): + for k, v in headers.items(): + if isinstance(k, str) and k.lower().startswith("x-portkey-"): + if k.lower() in {"x-portkey-api-key", "x-portkey-virtual-key"}: + redacted[k] = "***" + else: + redacted[k] = v + except Exception: + continue + if redacted: + metadata["portkeyHeaders"] = redacted + else: + logger.debug( + "Openlayer Portkey tracer: No x-portkey-* headers detected on client; provider/config metadata may be limited." + ) + return metadata + + +def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]: + """Extract metadata from a response or chunk unit (headers, ids).""" + metadata: Dict[str, Any] = {} + try: + # Extract system fingerprint if available (OpenAI-compatible) + if hasattr(unit, "system_fingerprint"): + metadata["system_fingerprint"] = unit.system_fingerprint + if hasattr(unit, "service_tier"): + metadata["service_tier"] = unit.service_tier + + # Response headers may be present on the object + headers_obj = None + if hasattr(unit, "_response_headers"): + headers_obj = getattr(unit, "_response_headers") + elif hasattr(unit, "response_headers"): + headers_obj = getattr(unit, "response_headers") + elif hasattr(unit, "_headers"): + headers_obj = getattr(unit, "_headers") + + if _is_dict_like(headers_obj): + headers = {str(k): v for k, v in headers_obj.items()} + metadata["response_headers"] = headers + # Known Portkey header hints (names are lower-cased defensively) + lower = {k.lower(): v for k, v in headers.items()} + if "x-portkey-trace-id" in lower: + metadata["portkey_trace_id"] = lower["x-portkey-trace-id"] + if "x-portkey-cache-status" in lower: + metadata["portkey_cache_status"] = lower["x-portkey-cache-status"] + if "x-portkey-retry-attempt-count" in lower: + metadata["portkey_retry_attempt_count"] = lower["x-portkey-retry-attempt-count"] + if "x-portkey-last-used-option-index" in lower: + metadata["portkey_last_used_option_index"] = lower["x-portkey-last-used-option-index"] + except Exception: + pass + # Attach model for convenience + if model_name: + metadata["portkey_model"] = model_name + return metadata + + +def extract_usage(obj: Any) -> Dict[str, Optional[int]]: + """Extract usage from a response or chunk object. + + This function attempts to extract token usage information from various + locations where it might be stored, including: + - Direct `usage` attribute + - `model_dump()` dictionary (for streaming chunks) + + Parameters + ---------- + obj : Any + The response or chunk object to extract usage from. + + Returns + ------- + Dict[str, Optional[int]] + Dictionary with keys: total_tokens, prompt_tokens, completion_tokens. + Values are None if usage information is not found. 
+ """ + try: + # Check for direct usage attribute (works for both response and chunk) + if hasattr(obj, "usage") and obj.usage is not None: + usage = obj.usage + return { + "total_tokens": getattr(usage, "total_tokens", None), + "prompt_tokens": getattr(usage, "prompt_tokens", None), + "completion_tokens": getattr(usage, "completion_tokens", None), + } + + # Check if object model dump has usage (primarily for streaming chunks) + if hasattr(obj, "model_dump"): + obj_dict = obj.model_dump() + if _supports_membership_check(obj_dict) and "usage" in obj_dict and obj_dict["usage"]: + usage = obj_dict["usage"] + return { + "total_tokens": usage.get("total_tokens", None), + "prompt_tokens": usage.get("prompt_tokens", None), + "completion_tokens": usage.get("completion_tokens", None), + } + except Exception: + pass + return {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None} + + +def calculate_streaming_usage_and_cost( + chunks: Any, + messages: Any, + output_content: Any, + model_name: str, + latest_usage_data: Dict[str, Optional[int]], + latest_chunk_metadata: Dict[str, Any], +): + """Calculate usage and cost at the end of streaming.""" + try: + # Priority 1: Actual usage provided in chunks + if latest_usage_data and latest_usage_data.get("total_tokens") and latest_usage_data.get("total_tokens") > 0: + return ( + latest_usage_data.get("completion_tokens"), + latest_usage_data.get("prompt_tokens"), + latest_usage_data.get("total_tokens"), + latest_chunk_metadata.get("cost"), + ) + + # Priority 2: Look for usage embedded in final chunk dicts (if raw dicts) + if isinstance(chunks, list): + for chunk_data in reversed(chunks): + if _supports_membership_check(chunk_data) and "usage" in chunk_data and chunk_data["usage"]: + usage = chunk_data["usage"] + if usage.get("total_tokens", 0) > 0: + return ( + usage.get("completion_tokens"), + usage.get("prompt_tokens"), + usage.get("total_tokens"), + latest_chunk_metadata.get("cost"), + ) + + # Priority 3: Estimate tokens + completion_tokens = None + prompt_tokens = None + total_tokens = None + cost = None + + # Estimate completion tokens + if isinstance(output_content, str): + completion_tokens = max(1, len(output_content) // 4) + elif _is_dict_like(output_content): + json_str = json.dumps(output_content) if output_content else "{}" + completion_tokens = max(1, len(json_str) // 4) + else: + # Fallback: count chunks present + try: + completion_tokens = len([c for c in chunks if c]) + except Exception: + completion_tokens = None + + # Estimate prompt tokens from messages + if messages: + total_chars = 0 + try: + for message in messages: + if _supports_membership_check(message) and "content" in message: + total_chars += len(str(message["content"])) + except Exception: + total_chars = 0 + prompt_tokens = max(1, total_chars // 4) if total_chars > 0 else 0 + else: + prompt_tokens = 0 + + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + + # Cost from metadata if present; otherwise simple heuristic for some models + cost = latest_chunk_metadata.get("cost") + if cost is None and total_tokens and model_name: + ml = model_name.lower() + if "gpt-3.5-turbo" in ml: + cost = (prompt_tokens * 0.0005 / 1000.0) + (completion_tokens * 0.0015 / 1000.0) + + return completion_tokens, prompt_tokens, total_tokens, cost + except Exception as e: + logger.error("Error calculating streaming usage: %s", e) + return None, None, None, None + + +def _extract_provider_from_object(obj: Any) -> Optional[str]: + """Extract provider from a response or chunk 
object. + + Checks response_metadata for provider information. + Returns None if no provider is found. + """ + try: + # Check response_metadata + if hasattr(obj, "response_metadata") and _is_dict_like(obj.response_metadata): + if "provider" in obj.response_metadata: + return obj.response_metadata["provider"] + except Exception: + pass + return None + + +def detect_provider(obj: Any, client: "Portkey", model_name: str) -> str: + """Detect provider from a response or chunk object. + + Parameters + ---------- + obj : Any + The response or chunk object to extract provider information from. + client : Portkey + The Portkey client instance. + model_name : str + The model name to use as a fallback for provider detection. + + Returns + ------- + str + The detected provider name. + """ + # First: check Portkey headers on the client (authoritative) + provider = _provider_from_portkey_headers(client) + if provider: + return provider + # Next: check object metadata if any + provider = _extract_provider_from_object(obj) + if provider: + return provider + # Fallback to model name heuristics + return detect_provider_from_model_name(model_name) + + +def detect_provider_from_model_name(model_name: str) -> str: + """Detect provider from model name patterns.""" + model_lower = (model_name or "").lower() + if model_lower.startswith(("gpt-", "o1-", "text-davinci", "text-curie", "text-babbage", "text-ada")): + return "OpenAI" + if model_lower.startswith(("claude-", "claude")): + return "Anthropic" + if "gemini" in model_lower or "palm" in model_lower: + return "Google" + if "llama" in model_lower or "meta-" in model_lower: + return "Meta" + if model_lower.startswith("mistral") or "mixtral" in model_lower: + return "Mistral" + if model_lower.startswith("command"): + return "Cohere" + return "Portkey" + + +def get_delta_from_chunk(chunk: Any) -> Any: + """Extract delta from chunk, handling different response formats.""" + try: + if hasattr(chunk, "choices") and chunk.choices: + choice = chunk.choices[0] + if hasattr(choice, "delta"): + return choice.delta + except Exception: + pass + return None + + +def _provider_from_portkey_headers(client: "Portkey") -> Optional[str]: + """Get provider from Portkey headers on the client.""" + header_sources = ("default_headers", "headers", "_default_headers", "_headers") + for attr in header_sources: + try: + headers = getattr(client, attr, None) + if _is_dict_like(headers): + for k, v in headers.items(): + if isinstance(k, str) and k.lower() == "x-portkey-provider" and v: + return str(v) + except Exception: + continue + return None + + +def _is_dict_like(obj: Any) -> bool: + """Check if an object is dict-like (has .items() method). + + This is more robust than isinstance(obj, dict) as it handles + custom dict-like objects (e.g., CaseInsensitiveDict, custom headers). + """ + return hasattr(obj, "items") and callable(getattr(obj, "items", None)) + + +def _supports_membership_check(obj: Any) -> bool: + """Check if an object supports membership testing (e.g., 'key in obj'). + + This checks for __contains__ method or if it's dict-like. 
+ """ + return hasattr(obj, "__contains__") or _is_dict_like(obj) diff --git a/tests/test_integration_conditional_imports.py b/tests/test_integration_conditional_imports.py index f673b480..adb5bf9d 100644 --- a/tests/test_integration_conditional_imports.py +++ b/tests/test_integration_conditional_imports.py @@ -34,6 +34,7 @@ "oci_tracer": ["oci"], "langchain_callback": ["langchain", "langchain_core", "langchain_community"], "litellm_tracer": ["litellm"], + "portkey_tracer": ["portkey_ai"], } # Expected patterns for integration modules diff --git a/tests/test_portkey_integration.py b/tests/test_portkey_integration.py new file mode 100644 index 00000000..97e2d50e --- /dev/null +++ b/tests/test_portkey_integration.py @@ -0,0 +1,576 @@ +"""Test Portkey tracer integration.""" + +import json +from types import SimpleNamespace +from typing import Any, Dict +from unittest.mock import Mock, patch + +import pytest # type: ignore + + +class TestPortkeyIntegration: + """Test Portkey integration functionality.""" + + def test_import_without_portkey(self) -> None: + """Module should import even when Portkey is unavailable.""" + from openlayer.lib.integrations import portkey_tracer # noqa: F401 + + assert hasattr(portkey_tracer, "HAVE_PORTKEY") + + def test_trace_portkey_raises_import_error_without_dependency(self) -> None: + """trace_portkey should raise ImportError when Portkey is missing.""" + with patch("openlayer.lib.integrations.portkey_tracer.HAVE_PORTKEY", False): + from openlayer.lib.integrations.portkey_tracer import trace_portkey + + with pytest.raises(ImportError) as exc_info: # type: ignore + trace_portkey() + + message = str(exc_info.value) # type: ignore[attr-defined] + assert "Portkey library is not installed" in message + assert "pip install portkey-ai" in message + + def test_trace_portkey_patches_portkey_client(self) -> None: + """trace_portkey should wrap Portkey chat completions for tracing.""" + + class DummyPortkey: # pylint: disable=too-few-public-methods + """Lightweight Portkey stand-in used for patching behavior.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401 + completions = SimpleNamespace(create=Mock(name="original_create")) + self.chat = SimpleNamespace(completions=completions) + self._init_args = (args, kwargs) + self.original_create = completions.create + + with patch("openlayer.lib.integrations.portkey_tracer.HAVE_PORTKEY", True), patch( + "openlayer.lib.integrations.portkey_tracer.Portkey", DummyPortkey, create=True + ), patch( + "openlayer.lib.integrations.portkey_tracer.handle_non_streaming_create", + autospec=True, + ) as mock_non_streaming, patch( + "openlayer.lib.integrations.portkey_tracer.handle_streaming_create", + autospec=True, + ) as mock_streaming: + mock_non_streaming.return_value = "non-stream-result" + mock_streaming.return_value = "stream-result" + + from openlayer.lib.integrations.portkey_tracer import trace_portkey + + trace_portkey() + + client = DummyPortkey() + # Non-streaming + result_non_stream = client.chat.completions.create(messages=[{"role": "user", "content": "hi"}]) + assert result_non_stream == "non-stream-result" + assert mock_non_streaming.call_count == 1 + non_stream_kwargs = mock_non_streaming.call_args.kwargs + assert non_stream_kwargs["create_func"] is client.original_create + assert non_stream_kwargs["inference_id"] is None + + # Streaming path + result_stream = client.chat.completions.create( + messages=[{"role": "user", "content": "hi"}], stream=True, inference_id="inference-123" + ) + assert 
result_stream == "stream-result" + assert mock_streaming.call_count == 1 + stream_kwargs = mock_streaming.call_args.kwargs + assert stream_kwargs["create_func"] is client.original_create + assert stream_kwargs["inference_id"] == "inference-123" + + def test_detect_provider_from_model_name(self) -> None: + """Provider detection should match model naming heuristics.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider_from_model_name + + test_cases = [ + ("gpt-4", "OpenAI"), + ("Gpt-3.5-turbo", "OpenAI"), + ("claude-3-opus", "Anthropic"), + ("gemini-pro", "Google"), + ("meta-llama-3-70b", "Meta"), + ("mixtral-8x7b", "Mistral"), + ("command-r", "Cohere"), + ("unknown-model", "Portkey"), + ] + + for model_name, expected in test_cases: + assert detect_provider_from_model_name(model_name) == expected + + def test_get_model_parameters(self) -> None: + """Ensure OpenAI-compatible kwargs are extracted.""" + from openlayer.lib.integrations.portkey_tracer import get_model_parameters + + kwargs = { + "temperature": 0.5, + "top_p": 0.7, + "max_tokens": 256, + "n": 3, + "stream": True, + "stop": ["END"], + "presence_penalty": 0.1, + "frequency_penalty": -0.1, + "logit_bias": {"1": -1}, + "logprobs": True, + "top_logprobs": 5, + "parallel_tool_calls": False, + "seed": 123, + "response_format": {"type": "json_object"}, + "timeout": 42, + "api_base": "https://api.example.com", + "api_version": "2024-05-01", + } + + params = get_model_parameters(kwargs) + + expected = kwargs.copy() + assert params == expected + + def test_extract_portkey_metadata(self) -> None: + """Portkey metadata should redact sensitive headers and include base URL.""" + from openlayer.lib.integrations.portkey_tracer import extract_portkey_metadata + + client = SimpleNamespace( + base_url="https://gateway.portkey.ai", + headers={ + "X-Portkey-Api-Key": "secret", + "X-Portkey-Provider": "openai", + "Authorization": "Bearer ignored", + }, + ) + + metadata = extract_portkey_metadata(client) + + assert metadata["isPortkey"] is True + assert metadata["portkeyBaseUrl"] == "https://gateway.portkey.ai" + assert metadata["portkeyHeaders"]["X-Portkey-Api-Key"] == "***" + assert metadata["portkeyHeaders"]["X-Portkey-Provider"] == "openai" + assert "Authorization" not in metadata["portkeyHeaders"] + + def test_extract_portkey_unit_metadata(self) -> None: + """Unit metadata should capture headers and retry/option index hints.""" + from openlayer.lib.integrations.portkey_tracer import extract_portkey_unit_metadata + + unit = SimpleNamespace( + system_fingerprint="fingerprint-123", + _response_headers={ + "x-portkey-trace-id": "trace-1", + "x-portkey-cache-status": "HIT", + "x-portkey-retry-attempt-count": "2", + "x-portkey-last-used-option-index": "config.targets[1]", + "content-type": "application/json", + }, + ) + + metadata = extract_portkey_unit_metadata(unit, "claude-3-opus") + + assert metadata["system_fingerprint"] == "fingerprint-123" + assert metadata["portkey_trace_id"] == "trace-1" + assert metadata["portkey_cache_status"] == "HIT" + assert metadata["portkey_retry_attempt_count"] == "2" + assert metadata["portkey_last_used_option_index"] == "config.targets[1]" + assert metadata["portkey_model"] == "claude-3-opus" + assert metadata["response_headers"]["content-type"] == "application/json" + + def test_extract_portkey_unit_metadata_with_dict_like_headers(self) -> None: + """Unit metadata should work with dict-like objects (not just dicts).""" + from openlayer.lib.integrations.portkey_tracer import 
extract_portkey_unit_metadata + + # Create a dict-like object (has .items() but not isinstance(dict)) + class DictLikeHeaders: + def __init__(self): + self._data = { + "x-portkey-trace-id": "trace-2", + "x-portkey-cache-status": "MISS", + "x-portkey-retry-attempt-count": "3", + "x-portkey-last-used-option-index": "config.targets[1]", + } + + def items(self): + return self._data.items() + + unit = SimpleNamespace( + _response_headers=DictLikeHeaders(), + ) + + metadata = extract_portkey_unit_metadata(unit, "gpt-4") + + assert metadata["portkey_trace_id"] == "trace-2" + assert metadata["portkey_cache_status"] == "MISS" + assert metadata["portkey_retry_attempt_count"] == "3" + assert metadata["portkey_last_used_option_index"] == "config.targets[1]" + + def test_extract_usage_from_response(self) -> None: + """Usage extraction should read OpenAI-style usage objects.""" + from openlayer.lib.integrations.portkey_tracer import extract_usage + + usage = SimpleNamespace(total_tokens=50, prompt_tokens=20, completion_tokens=30) + response = SimpleNamespace(usage=usage) + + assert extract_usage(response) == { + "total_tokens": 50, + "prompt_tokens": 20, + "completion_tokens": 30, + } + + response_no_usage = SimpleNamespace() + assert extract_usage(response_no_usage) == { + "total_tokens": None, + "prompt_tokens": None, + "completion_tokens": None, + } + + def test_extract_usage_from_chunk(self) -> None: + """Usage data should be derived from multiple potential chunk attributes.""" + from openlayer.lib.integrations.portkey_tracer import extract_usage + + chunk_direct = SimpleNamespace( + usage=SimpleNamespace(total_tokens=120, prompt_tokens=40, completion_tokens=80) + ) + assert extract_usage(chunk_direct) == { + "total_tokens": 120, + "prompt_tokens": 40, + "completion_tokens": 80, + } + + class ChunkWithModelDump: # pylint: disable=too-few-public-methods + def model_dump(self) -> Dict[str, Any]: + return {"usage": {"total_tokens": 12, "prompt_tokens": 5, "completion_tokens": 7}} + + assert extract_usage(ChunkWithModelDump()) == { + "total_tokens": 12, + "prompt_tokens": 5, + "completion_tokens": 7, + } + + def test_calculate_streaming_usage_and_cost_with_actual_usage(self) -> None: + """Actual usage data should be returned when available.""" + from openlayer.lib.integrations.portkey_tracer import calculate_streaming_usage_and_cost + + latest_usage = {"total_tokens": 100, "prompt_tokens": 40, "completion_tokens": 60} + latest_metadata = {"cost": 0.99} + + result = calculate_streaming_usage_and_cost( + chunks=[], + messages=[], + output_content="", + model_name="gpt-4", + latest_usage_data=latest_usage, + latest_chunk_metadata=latest_metadata, + ) + + assert result == (60, 40, 100, 0.99) + + def test_calculate_streaming_usage_and_cost_fallback_estimation(self) -> None: + """Fallback estimation should approximate tokens and cost when usage is missing.""" + from openlayer.lib.integrations.portkey_tracer import calculate_streaming_usage_and_cost + + output_content = "Generated answer text." 
+ messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Tell me something interesting."}, + ] + + completion_tokens, prompt_tokens, total_tokens, cost = calculate_streaming_usage_and_cost( + chunks=[{"usage": None}], + messages=messages, + output_content=output_content, + model_name="gpt-3.5-turbo", + latest_usage_data={"total_tokens": None, "prompt_tokens": None, "completion_tokens": None}, + latest_chunk_metadata={}, + ) + + assert completion_tokens >= 1 + assert prompt_tokens >= 1 + assert total_tokens == (completion_tokens or 0) + (prompt_tokens or 0) + assert cost is not None + assert cost >= 0 + + def test_detect_provider_from_response_prefers_headers(self) -> None: + """Provider detection should prioritize Portkey headers.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider + + client = SimpleNamespace() + response = SimpleNamespace() + + with patch( + "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider" + ): + assert detect_provider(response, client, "gpt-4") == "header-provider" + + def test_detect_provider_from_chunk_prefers_headers(self) -> None: + """Provider detection from chunk should prioritize header-derived values.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider + + client = SimpleNamespace() + chunk = SimpleNamespace() + + with patch( + "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider" + ): + assert detect_provider(chunk, client, "gpt-4") == "header-provider" + + def test_detect_provider_from_response_fallback(self) -> None: + """Provider detection should fall back to response metadata or model name.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider + + client = SimpleNamespace() + response = SimpleNamespace( + response_metadata={"provider": "anthropic"}, + ) + + with patch( + "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None + ): + assert detect_provider(response, client, "mistral-7b") == "anthropic" + + def test_detect_provider_from_chunk_fallback(self) -> None: + """Chunk provider detection should fall back gracefully.""" + from openlayer.lib.integrations.portkey_tracer import detect_provider + + chunk = SimpleNamespace( + response_metadata={"provider": "cohere"}, + ) + client = SimpleNamespace() + + with patch( + "openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None + ): + assert detect_provider(chunk, client, "command-r") == "cohere" + + def test_provider_from_portkey_headers(self) -> None: + """Header helper should identify provider values on the client.""" + from openlayer.lib.integrations.portkey_tracer import _provider_from_portkey_headers + + client = SimpleNamespace( + default_headers={"X-Portkey-Provider": "openai"}, + headers={"X-Portkey-Provider": "anthropic"}, + ) + + assert _provider_from_portkey_headers(client) == "openai" + + def test_parse_non_streaming_output_data(self) -> None: + """Output parsing should support content, function calls, and tool calls.""" + from openlayer.lib.integrations.portkey_tracer import parse_non_streaming_output_data + + # Content message + response_content = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", function_call=None, tool_calls=None))] + ) + assert parse_non_streaming_output_data(response_content) == "Hello!" 
+ + # Function call + response_function = SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace( + content=None, + function_call=SimpleNamespace(name="do_something", arguments=json.dumps({"value": 1})), + tool_calls=None, + ) + ) + ] + ) + assert parse_non_streaming_output_data(response_function) == {"name": "do_something", "arguments": {"value": 1}} + + # Tool call + response_tool = SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace( + content=None, + function_call=None, + tool_calls=[ + SimpleNamespace( + function=SimpleNamespace(name="call_tool", arguments=json.dumps({"value": 2})) + ) + ], + ) + ) + ] + ) + assert parse_non_streaming_output_data(response_tool) == {"name": "call_tool", "arguments": {"value": 2}} + + def test_create_trace_args(self) -> None: + """Trace argument helper should include optional id and cost.""" + from openlayer.lib.integrations.portkey_tracer import create_trace_args + + args = create_trace_args( + end_time=1.0, + inputs={"prompt": []}, + output="response", + latency=123.0, + tokens=10, + prompt_tokens=4, + completion_tokens=6, + model="gpt-4", + id="trace-id", + cost=0.42, + ) + + assert args["id"] == "trace-id" + assert args["cost"] == 0.42 + assert args["metadata"] == {} + + def test_add_to_trace_uses_provider_metadata(self) -> None: + """add_to_trace should pass provider metadata through to tracer.""" + from openlayer.lib.integrations.portkey_tracer import add_to_trace + + with patch( + "openlayer.lib.integrations.portkey_tracer.tracer.add_chat_completion_step_to_trace" + ) as mock_add: + add_to_trace( + end_time=1.0, + inputs={}, + output=None, + latency=10.0, + tokens=None, + prompt_tokens=None, + completion_tokens=None, + model="model", + metadata={}, + ) + + _, kwargs = mock_add.call_args + assert kwargs["provider"] == "Portkey" + assert kwargs["name"] == "Portkey Chat Completion" + + add_to_trace( + end_time=2.0, + inputs={}, + output=None, + latency=5.0, + tokens=None, + prompt_tokens=None, + completion_tokens=None, + model="model", + metadata={"provider": "OpenAI"}, + ) + + assert mock_add.call_count == 2 + assert mock_add.call_args.kwargs["provider"] == "OpenAI" + + def test_handle_streaming_create_delegates_to_stream_chunks(self) -> None: + """handle_streaming_create should call the original create and stream_chunks.""" + from openlayer.lib.integrations.portkey_tracer import handle_streaming_create + + client = SimpleNamespace() + create_func = Mock(return_value=iter(["chunk"])) + + with patch( + "openlayer.lib.integrations.portkey_tracer.stream_chunks", return_value=iter(["chunk"]) + ) as mock_stream_chunks: + result_iterator = handle_streaming_create( + client, + "arg-1", + create_func=create_func, + inference_id="stream-id", + foo="bar", + ) + + assert list(result_iterator) == ["chunk"] + create_func.assert_called_once_with("arg-1", foo="bar") + mock_stream_chunks.assert_called_once() + stream_kwargs = mock_stream_chunks.call_args.kwargs + assert stream_kwargs["client"] is client + assert stream_kwargs["inference_id"] == "stream-id" + assert stream_kwargs["kwargs"] == {"foo": "bar"} + assert stream_kwargs["chunks"] is create_func.return_value + + def test_stream_chunks_traces_completion(self) -> None: + """stream_chunks should yield all chunks and record a traced step.""" + from openlayer.lib.integrations.portkey_tracer import stream_chunks + + chunk_a = object() + chunk_b = object() + chunks = [chunk_a, chunk_b] + kwargs = {"messages": [{"role": "user", "content": "hello"}]} + client = 
SimpleNamespace() + + with patch( + "openlayer.lib.integrations.portkey_tracer.add_to_trace", autospec=True + ) as mock_add_to_trace, patch( + "openlayer.lib.integrations.portkey_tracer.extract_usage", autospec=True + ) as mock_usage, patch( + "openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata", autospec=True + ) as mock_unit_metadata, patch( + "openlayer.lib.integrations.portkey_tracer.detect_provider", autospec=True + ) as mock_detect_provider, patch( + "openlayer.lib.integrations.portkey_tracer.get_delta_from_chunk", autospec=True + ) as mock_delta, patch( + "openlayer.lib.integrations.portkey_tracer.calculate_streaming_usage_and_cost", autospec=True + ) as mock_calc, patch( + "openlayer.lib.integrations.portkey_tracer.extract_portkey_metadata", autospec=True + ) as mock_client_metadata, patch( + "openlayer.lib.integrations.portkey_tracer.time.time", side_effect=[100.0, 100.05, 100.2] + ): + mock_usage.side_effect = [ + {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None}, + {"total_tokens": 10, "prompt_tokens": 4, "completion_tokens": 6}, + ] + mock_unit_metadata.side_effect = [{}, {"cost": 0.1}] + mock_detect_provider.side_effect = ["OpenAI", "OpenAI"] + mock_delta.side_effect = [ + SimpleNamespace(content="Hello ", function_call=None, tool_calls=None), + SimpleNamespace(content="World", function_call=None, tool_calls=None), + ] + mock_calc.return_value = (6, 4, 10, 0.1) + mock_client_metadata.return_value = {"portkeyBaseUrl": "https://gateway"} + + yielded = list( + stream_chunks( + chunks=iter(chunks), + kwargs=kwargs, + client=client, + inference_id="trace-123", + ) + ) + + assert yielded == chunks + mock_add_to_trace.assert_called_once() + trace_kwargs = mock_add_to_trace.call_args.kwargs + assert trace_kwargs["metadata"]["provider"] == "OpenAI" + assert trace_kwargs["metadata"]["portkeyBaseUrl"] == "https://gateway" + assert trace_kwargs["id"] == "trace-123" + assert trace_kwargs["tokens"] == 10 + assert trace_kwargs["latency"] == pytest.approx(200.0) + + def test_handle_non_streaming_create_traces_completion(self) -> None: + """handle_non_streaming_create should record a traced step for completions.""" + from openlayer.lib.integrations.portkey_tracer import handle_non_streaming_create + + response = SimpleNamespace(model="gpt-4", system_fingerprint="fp-1") + client = SimpleNamespace() + create_func = Mock(return_value=response) + + with patch( + "openlayer.lib.integrations.portkey_tracer.parse_non_streaming_output_data", return_value="output" + ), patch( + "openlayer.lib.integrations.portkey_tracer.extract_usage", + return_value={"total_tokens": 10, "prompt_tokens": 4, "completion_tokens": 6}, + ), patch( + "openlayer.lib.integrations.portkey_tracer.detect_provider", return_value="OpenAI" + ), patch( + "openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata", + return_value={"cost": 0.25}, + ), patch( + "openlayer.lib.integrations.portkey_tracer.extract_portkey_metadata", + return_value={"portkeyHeaders": {"X-Portkey-Provider": "openai"}}, + ), patch( + "openlayer.lib.integrations.portkey_tracer.add_to_trace" + ) as mock_add_to_trace, patch( + "openlayer.lib.integrations.portkey_tracer.time.time", side_effect=[10.0, 10.2] + ): + result = handle_non_streaming_create( + client, + create_func=create_func, + inference_id="trace-xyz", + messages=[{"role": "user", "content": "Hello"}], + ) + + assert result is response + mock_add_to_trace.assert_called_once() + trace_kwargs = mock_add_to_trace.call_args.kwargs + assert 
trace_kwargs["id"] == "trace-xyz" + assert trace_kwargs["metadata"]["provider"] == "OpenAI" + assert trace_kwargs["metadata"]["cost"] == 0.25 + assert trace_kwargs["metadata"]["portkeyHeaders"]["X-Portkey-Provider"] == "openai" + +