Skip to content

Commit 23618da

Browse files
refactoring to increase reusability of code
1 parent a4c1d5a commit 23618da

File tree

2 files changed

+90
-82
lines changed

2 files changed

+90
-82
lines changed

src/openlayer/lib/integrations/portkey_tracer.py

Lines changed: 71 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,12 @@ def stream_chunks(
185185
if i == 0:
186186
first_token_time = time.time()
187187
# Try to detect provider at first chunk
188-
provider = detect_provider_from_chunk(chunk, client, model_name)
188+
provider = detect_provider(chunk, client, model_name)
189189
if i > 0:
190190
num_of_completion_tokens = i + 1
191191

192192
# Extract usage from chunk if available
193-
chunk_usage = extract_usage_from_chunk(chunk)
193+
chunk_usage = extract_usage(chunk)
194194
if any(v is not None for v in chunk_usage.values()):
195195
latest_usage_data = chunk_usage
196196

@@ -317,9 +317,9 @@ def handle_non_streaming_create(
317317
output_data = parse_non_streaming_output_data(response)
318318

319319
# Usage (if provided by upstream provider via Portkey)
320-
usage_data = extract_usage_from_response(response)
320+
usage_data = extract_usage(response)
321321
model_name = getattr(response, "model", kwargs.get("model", "unknown"))
322-
provider = detect_provider_from_response(response, client, model_name)
322+
provider = detect_provider(response, client, model_name)
323323
extra_metadata = extract_portkey_unit_metadata(response, model_name)
324324
cost = extra_metadata.get("cost", None)
325325

@@ -530,36 +530,39 @@ def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]:
530530
return metadata
531531

532532

533-
def extract_usage_from_response(response: Any) -> Dict[str, Optional[int]]:
534-
"""Extract usage from a non-streaming response."""
535-
try:
536-
if hasattr(response, "usage") and response.usage is not None:
537-
usage = response.usage
538-
return {
539-
"total_tokens": getattr(usage, "total_tokens", None),
540-
"prompt_tokens": getattr(usage, "prompt_tokens", None),
541-
"completion_tokens": getattr(usage, "completion_tokens", None),
542-
}
543-
except Exception:
544-
pass
545-
return {"total_tokens": None, "prompt_tokens": None, "completion_tokens": None}
546-
547-
548-
def extract_usage_from_chunk(chunk: Any) -> Dict[str, Optional[int]]:
549-
"""Extract usage from a streaming chunk if present."""
533+
def extract_usage(obj: Any) -> Dict[str, Optional[int]]:
534+
"""Extract usage from a response or chunk object.
535+
536+
This function attempts to extract token usage information from various
537+
locations where it might be stored, including:
538+
- Direct `usage` attribute
539+
- `_hidden_params` (for streaming chunks)
540+
- `model_dump()` dictionary (for streaming chunks)
541+
542+
Parameters
543+
----------
544+
obj : Any
545+
The response or chunk object to extract usage from.
546+
547+
Returns
548+
-------
549+
Dict[str, Optional[int]]
550+
Dictionary with keys: total_tokens, prompt_tokens, completion_tokens.
551+
Values are None if usage information is not found.
552+
"""
550553
try:
551-
# Check for usage attribute
552-
if hasattr(chunk, "usage") and chunk.usage is not None:
553-
usage = chunk.usage
554+
# Check for direct usage attribute (works for both response and chunk)
555+
if hasattr(obj, "usage") and obj.usage is not None:
556+
usage = obj.usage
554557
return {
555558
"total_tokens": getattr(usage, "total_tokens", None),
556559
"prompt_tokens": getattr(usage, "prompt_tokens", None),
557560
"completion_tokens": getattr(usage, "completion_tokens", None),
558561
}
559562

560-
# Check for usage in _hidden_params (if SDK stores it there)
561-
if hasattr(chunk, "_hidden_params"):
562-
hidden_params = chunk._hidden_params
563+
# Check for usage in _hidden_params (primarily for streaming chunks)
564+
if hasattr(obj, "_hidden_params"):
565+
hidden_params = obj._hidden_params
563566
# Check if usage is a direct attribute
564567
if hasattr(hidden_params, "usage") and hidden_params.usage is not None:
565568
usage = hidden_params.usage
@@ -578,11 +581,11 @@ def extract_usage_from_chunk(chunk: Any) -> Dict[str, Optional[int]]:
578581
"completion_tokens": usage.get("completion_tokens", None),
579582
}
580583

581-
# Check if chunk model dump has usage
582-
if hasattr(chunk, "model_dump"):
583-
chunk_dict = chunk.model_dump()
584-
if _supports_membership_check(chunk_dict) and "usage" in chunk_dict and chunk_dict["usage"]:
585-
usage = chunk_dict["usage"]
584+
# Check if object model dump has usage (primarily for streaming chunks)
585+
if hasattr(obj, "model_dump"):
586+
obj_dict = obj.model_dump()
587+
if _supports_membership_check(obj_dict) and "usage" in obj_dict and obj_dict["usage"]:
588+
usage = obj_dict["usage"]
586589
return {
587590
"total_tokens": usage.get("total_tokens", None),
588591
"prompt_tokens": usage.get("prompt_tokens", None),
@@ -672,49 +675,54 @@ def calculate_streaming_usage_and_cost(
672675
return None, None, None, None
673676

674677

675-
def detect_provider_from_response(response: Any, client: "Portkey", model_name: str) -> str:
676-
"""Detect provider for non-streaming responses."""
677-
# First: check Portkey headers on the client (authoritative)
678-
provider = _provider_from_portkey_headers(client)
679-
if provider:
680-
return provider
681-
# Next: check response metadata if any
678+
def _extract_provider_from_object(obj: Any) -> Optional[str]:
679+
"""Extract provider from a response or chunk object.
680+
681+
Checks response_metadata and _response_headers for provider information.
682+
Returns None if no provider is found.
683+
"""
682684
try:
683-
# Some SDKs attach response headers/metadata
684-
if hasattr(response, "response_metadata") and _is_dict_like(response.response_metadata):
685-
if "provider" in response.response_metadata:
686-
return response.response_metadata["provider"]
687-
if hasattr(response, "_response_headers"):
688-
headers = getattr(response, "_response_headers")
685+
# Check response_metadata
686+
if hasattr(obj, "response_metadata") and _is_dict_like(obj.response_metadata):
687+
if "provider" in obj.response_metadata:
688+
return obj.response_metadata["provider"]
689+
# Check _response_headers
690+
if hasattr(obj, "_response_headers"):
691+
headers = getattr(obj, "_response_headers")
689692
if _is_dict_like(headers):
690693
for k, v in headers.items():
691694
if isinstance(k, str) and k.lower() == "x-portkey-provider" and v:
692695
return str(v)
693696
except Exception:
694697
pass
695-
# Fallback to model name heuristics
696-
return detect_provider_from_model_name(model_name)
698+
return None
697699

698700

699-
def detect_provider_from_chunk(chunk: Any, client: "Portkey", model_name: str) -> str:
700-
"""Detect provider for streaming chunks."""
701-
# First: check Portkey headers on the client
701+
def detect_provider(obj: Any, client: "Portkey", model_name: str) -> str:
702+
"""Detect provider from a response or chunk object.
703+
704+
Parameters
705+
----------
706+
obj : Any
707+
The response or chunk object to extract provider information from.
708+
client : Portkey
709+
The Portkey client instance.
710+
model_name : str
711+
The model name to use as a fallback for provider detection.
712+
713+
Returns
714+
-------
715+
str
716+
The detected provider name.
717+
"""
718+
# First: check Portkey headers on the client (authoritative)
702719
provider = _provider_from_portkey_headers(client)
703720
if provider:
704721
return provider
705-
# Next: see if chunk exposes any metadata
706-
try:
707-
if hasattr(chunk, "response_metadata") and _is_dict_like(chunk.response_metadata):
708-
if "provider" in chunk.response_metadata:
709-
return chunk.response_metadata["provider"]
710-
if hasattr(chunk, "_response_headers"):
711-
headers = getattr(chunk, "_response_headers")
712-
if _is_dict_like(headers):
713-
for k, v in headers.items():
714-
if isinstance(k, str) and k.lower() == "x-portkey-provider" and v:
715-
return str(v)
716-
except Exception:
717-
pass
722+
# Next: check object metadata if any
723+
provider = _extract_provider_from_object(obj)
724+
if provider:
725+
return provider
718726
# Fallback to model name heuristics
719727
return detect_provider_from_model_name(model_name)
720728

tests/test_portkey_integration.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -197,32 +197,32 @@ def items(self):
197197

198198
def test_extract_usage_from_response(self) -> None:
199199
"""Usage extraction should read OpenAI-style usage objects."""
200-
from openlayer.lib.integrations.portkey_tracer import extract_usage_from_response
200+
from openlayer.lib.integrations.portkey_tracer import extract_usage
201201

202202
usage = SimpleNamespace(total_tokens=50, prompt_tokens=20, completion_tokens=30)
203203
response = SimpleNamespace(usage=usage)
204204

205-
assert extract_usage_from_response(response) == {
205+
assert extract_usage(response) == {
206206
"total_tokens": 50,
207207
"prompt_tokens": 20,
208208
"completion_tokens": 30,
209209
}
210210

211211
response_no_usage = SimpleNamespace()
212-
assert extract_usage_from_response(response_no_usage) == {
212+
assert extract_usage(response_no_usage) == {
213213
"total_tokens": None,
214214
"prompt_tokens": None,
215215
"completion_tokens": None,
216216
}
217217

218218
def test_extract_usage_from_chunk(self) -> None:
219219
"""Usage data should be derived from multiple potential chunk attributes."""
220-
from openlayer.lib.integrations.portkey_tracer import extract_usage_from_chunk
220+
from openlayer.lib.integrations.portkey_tracer import extract_usage
221221

222222
chunk_direct = SimpleNamespace(
223223
usage=SimpleNamespace(total_tokens=120, prompt_tokens=40, completion_tokens=80)
224224
)
225-
assert extract_usage_from_chunk(chunk_direct) == {
225+
assert extract_usage(chunk_direct) == {
226226
"total_tokens": 120,
227227
"prompt_tokens": 40,
228228
"completion_tokens": 80,
@@ -231,7 +231,7 @@ def test_extract_usage_from_chunk(self) -> None:
231231
chunk_hidden = SimpleNamespace(
232232
_hidden_params={"usage": {"total_tokens": 30, "prompt_tokens": 10, "completion_tokens": 20}}
233233
)
234-
assert extract_usage_from_chunk(chunk_hidden) == {
234+
assert extract_usage(chunk_hidden) == {
235235
"total_tokens": 30,
236236
"prompt_tokens": 10,
237237
"completion_tokens": 20,
@@ -241,7 +241,7 @@ class ChunkWithModelDump: # pylint: disable=too-few-public-methods
241241
def model_dump(self) -> Dict[str, Any]:
242242
return {"usage": {"total_tokens": 12, "prompt_tokens": 5, "completion_tokens": 7}}
243243

244-
assert extract_usage_from_chunk(ChunkWithModelDump()) == {
244+
assert extract_usage(ChunkWithModelDump()) == {
245245
"total_tokens": 12,
246246
"prompt_tokens": 5,
247247
"completion_tokens": 7,
@@ -292,31 +292,31 @@ def test_calculate_streaming_usage_and_cost_fallback_estimation(self) -> None:
292292

293293
def test_detect_provider_from_response_prefers_headers(self) -> None:
294294
"""Provider detection should prioritize Portkey headers."""
295-
from openlayer.lib.integrations.portkey_tracer import detect_provider_from_response
295+
from openlayer.lib.integrations.portkey_tracer import detect_provider
296296

297297
client = SimpleNamespace()
298298
response = SimpleNamespace()
299299

300300
with patch(
301301
"openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider"
302302
):
303-
assert detect_provider_from_response(response, client, "gpt-4") == "header-provider"
303+
assert detect_provider(response, client, "gpt-4") == "header-provider"
304304

305305
def test_detect_provider_from_chunk_prefers_headers(self) -> None:
306306
"""Provider detection from chunk should prioritize header-derived values."""
307-
from openlayer.lib.integrations.portkey_tracer import detect_provider_from_chunk
307+
from openlayer.lib.integrations.portkey_tracer import detect_provider
308308

309309
client = SimpleNamespace()
310310
chunk = SimpleNamespace()
311311

312312
with patch(
313313
"openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value="header-provider"
314314
):
315-
assert detect_provider_from_chunk(chunk, client, "gpt-4") == "header-provider"
315+
assert detect_provider(chunk, client, "gpt-4") == "header-provider"
316316

317317
def test_detect_provider_from_response_fallback(self) -> None:
318318
"""Provider detection should fall back to response metadata or model name."""
319-
from openlayer.lib.integrations.portkey_tracer import detect_provider_from_response
319+
from openlayer.lib.integrations.portkey_tracer import detect_provider
320320

321321
client = SimpleNamespace(headers={"x-portkey-provider": "openai"})
322322
response = SimpleNamespace(
@@ -327,11 +327,11 @@ def test_detect_provider_from_response_fallback(self) -> None:
327327
with patch(
328328
"openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None
329329
):
330-
assert detect_provider_from_response(response, client, "mistral-7b") == "anthropic"
330+
assert detect_provider(response, client, "mistral-7b") == "anthropic"
331331

332332
def test_detect_provider_from_chunk_fallback(self) -> None:
333333
"""Chunk provider detection should fall back gracefully."""
334-
from openlayer.lib.integrations.portkey_tracer import detect_provider_from_chunk
334+
from openlayer.lib.integrations.portkey_tracer import detect_provider
335335

336336
chunk = SimpleNamespace(
337337
response_metadata={"provider": "cohere"},
@@ -342,7 +342,7 @@ def test_detect_provider_from_chunk_fallback(self) -> None:
342342
with patch(
343343
"openlayer.lib.integrations.portkey_tracer._provider_from_portkey_headers", return_value=None
344344
):
345-
assert detect_provider_from_chunk(chunk, client, "command-r") == "cohere"
345+
assert detect_provider(chunk, client, "command-r") == "cohere"
346346

347347
def test_provider_from_portkey_headers(self) -> None:
348348
"""Header helper should identify provider values on the client."""
@@ -496,11 +496,11 @@ def test_stream_chunks_traces_completion(self) -> None:
496496
with patch(
497497
"openlayer.lib.integrations.portkey_tracer.add_to_trace", autospec=True
498498
) as mock_add_to_trace, patch(
499-
"openlayer.lib.integrations.portkey_tracer.extract_usage_from_chunk", autospec=True
499+
"openlayer.lib.integrations.portkey_tracer.extract_usage", autospec=True
500500
) as mock_usage, patch(
501501
"openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata", autospec=True
502502
) as mock_unit_metadata, patch(
503-
"openlayer.lib.integrations.portkey_tracer.detect_provider_from_chunk", autospec=True
503+
"openlayer.lib.integrations.portkey_tracer.detect_provider", autospec=True
504504
) as mock_detect_provider, patch(
505505
"openlayer.lib.integrations.portkey_tracer.get_delta_from_chunk", autospec=True
506506
) as mock_delta, patch(
@@ -552,10 +552,10 @@ def test_handle_non_streaming_create_traces_completion(self) -> None:
552552
with patch(
553553
"openlayer.lib.integrations.portkey_tracer.parse_non_streaming_output_data", return_value="output"
554554
), patch(
555-
"openlayer.lib.integrations.portkey_tracer.extract_usage_from_response",
555+
"openlayer.lib.integrations.portkey_tracer.extract_usage",
556556
return_value={"total_tokens": 10, "prompt_tokens": 4, "completion_tokens": 6},
557557
), patch(
558-
"openlayer.lib.integrations.portkey_tracer.detect_provider_from_response", return_value="OpenAI"
558+
"openlayer.lib.integrations.portkey_tracer.detect_provider", return_value="OpenAI"
559559
), patch(
560560
"openlayer.lib.integrations.portkey_tracer.extract_portkey_unit_metadata",
561561
return_value={"cost": 0.25},

0 commit comments

Comments (0)