@@ -185,12 +185,12 @@ def stream_chunks(
185185 if i == 0 :
186186 first_token_time = time .time ()
187187 # Try to detect provider at first chunk
188- provider = detect_provider_from_chunk (chunk , client , model_name )
188+ provider = detect_provider (chunk , client , model_name )
189189 if i > 0 :
190190 num_of_completion_tokens = i + 1
191191
192192 # Extract usage from chunk if available
193- chunk_usage = extract_usage_from_chunk (chunk )
193+ chunk_usage = extract_usage (chunk )
194194 if any (v is not None for v in chunk_usage .values ()):
195195 latest_usage_data = chunk_usage
196196
@@ -317,9 +317,9 @@ def handle_non_streaming_create(
317317 output_data = parse_non_streaming_output_data (response )
318318
319319 # Usage (if provided by upstream provider via Portkey)
320- usage_data = extract_usage_from_response (response )
320+ usage_data = extract_usage (response )
321321 model_name = getattr (response , "model" , kwargs .get ("model" , "unknown" ))
322- provider = detect_provider_from_response (response , client , model_name )
322+ provider = detect_provider (response , client , model_name )
323323 extra_metadata = extract_portkey_unit_metadata (response , model_name )
324324 cost = extra_metadata .get ("cost" , None )
325325
@@ -530,36 +530,39 @@ def extract_portkey_unit_metadata(unit: Any, model_name: str) -> Dict[str, Any]:
530530 return metadata
531531
532532
533- def extract_usage_from_response (response : Any ) -> Dict [str , Optional [int ]]:
534- """Extract usage from a non-streaming response."""
535- try :
536- if hasattr (response , "usage" ) and response .usage is not None :
537- usage = response .usage
538- return {
539- "total_tokens" : getattr (usage , "total_tokens" , None ),
540- "prompt_tokens" : getattr (usage , "prompt_tokens" , None ),
541- "completion_tokens" : getattr (usage , "completion_tokens" , None ),
542- }
543- except Exception :
544- pass
545- return {"total_tokens" : None , "prompt_tokens" : None , "completion_tokens" : None }
546-
547-
548- def extract_usage_from_chunk (chunk : Any ) -> Dict [str , Optional [int ]]:
549- """Extract usage from a streaming chunk if present."""
533+ def extract_usage (obj : Any ) -> Dict [str , Optional [int ]]:
534+ """Extract usage from a response or chunk object.
535+
536+ This function attempts to extract token usage information from various
537+ locations where it might be stored, including:
538+ - Direct `usage` attribute
539+ - `_hidden_params` (for streaming chunks)
540+ - `model_dump()` dictionary (for streaming chunks)
541+
542+ Parameters
543+ ----------
544+ obj : Any
545+ The response or chunk object to extract usage from.
546+
547+ Returns
548+ -------
549+ Dict[str, Optional[int]]
550+ Dictionary with keys: total_tokens, prompt_tokens, completion_tokens.
551+ Values are None if usage information is not found.
552+ """
550553 try :
551- # Check for usage attribute
552- if hasattr (chunk , "usage" ) and chunk .usage is not None :
553- usage = chunk .usage
554+ # Check for direct usage attribute (works for both response and chunk)
555+ if hasattr (obj , "usage" ) and obj .usage is not None :
556+ usage = obj .usage
554557 return {
555558 "total_tokens" : getattr (usage , "total_tokens" , None ),
556559 "prompt_tokens" : getattr (usage , "prompt_tokens" , None ),
557560 "completion_tokens" : getattr (usage , "completion_tokens" , None ),
558561 }
559562
560- # Check for usage in _hidden_params (if SDK stores it there )
561- if hasattr (chunk , "_hidden_params" ):
562- hidden_params = chunk ._hidden_params
563+ # Check for usage in _hidden_params (primarily for streaming chunks )
564+ if hasattr (obj , "_hidden_params" ):
565+ hidden_params = obj ._hidden_params
563566 # Check if usage is a direct attribute
564567 if hasattr (hidden_params , "usage" ) and hidden_params .usage is not None :
565568 usage = hidden_params .usage
@@ -578,11 +581,11 @@ def extract_usage_from_chunk(chunk: Any) -> Dict[str, Optional[int]]:
578581 "completion_tokens" : usage .get ("completion_tokens" , None ),
579582 }
580583
581- # Check if chunk model dump has usage
582- if hasattr (chunk , "model_dump" ):
583- chunk_dict = chunk .model_dump ()
584- if _supports_membership_check (chunk_dict ) and "usage" in chunk_dict and chunk_dict ["usage" ]:
585- usage = chunk_dict ["usage" ]
584+ # Check if object model dump has usage (primarily for streaming chunks)
585+ if hasattr (obj , "model_dump" ):
586+ obj_dict = obj .model_dump ()
587+ if _supports_membership_check (obj_dict ) and "usage" in obj_dict and obj_dict ["usage" ]:
588+ usage = obj_dict ["usage" ]
586589 return {
587590 "total_tokens" : usage .get ("total_tokens" , None ),
588591 "prompt_tokens" : usage .get ("prompt_tokens" , None ),
@@ -672,49 +675,54 @@ def calculate_streaming_usage_and_cost(
672675 return None , None , None , None
673676
674677
675- def detect_provider_from_response (response : Any , client : "Portkey" , model_name : str ) -> str :
676- """Detect provider for non-streaming responses."""
677- # First: check Portkey headers on the client (authoritative)
678- provider = _provider_from_portkey_headers (client )
679- if provider :
680- return provider
681- # Next: check response metadata if any
678+ def _extract_provider_from_object (obj : Any ) -> Optional [str ]:
679+ """Extract provider from a response or chunk object.
680+
681+ Checks response_metadata and _response_headers for provider information.
682+ Returns None if no provider is found.
683+ """
682684 try :
683- # Some SDKs attach response headers/metadata
684- if hasattr (response , "response_metadata" ) and _is_dict_like (response .response_metadata ):
685- if "provider" in response .response_metadata :
686- return response .response_metadata ["provider" ]
687- if hasattr (response , "_response_headers" ):
688- headers = getattr (response , "_response_headers" )
685+ # Check response_metadata
686+ if hasattr (obj , "response_metadata" ) and _is_dict_like (obj .response_metadata ):
687+ if "provider" in obj .response_metadata :
688+ return obj .response_metadata ["provider" ]
689+ # Check _response_headers
690+ if hasattr (obj , "_response_headers" ):
691+ headers = getattr (obj , "_response_headers" )
689692 if _is_dict_like (headers ):
690693 for k , v in headers .items ():
691694 if isinstance (k , str ) and k .lower () == "x-portkey-provider" and v :
692695 return str (v )
693696 except Exception :
694697 pass
695- # Fallback to model name heuristics
696- return detect_provider_from_model_name (model_name )
698+ return None
697699
698700
699- def detect_provider_from_chunk (chunk : Any , client : "Portkey" , model_name : str ) -> str :
700- """Detect provider for streaming chunks."""
701- # First: check Portkey headers on the client
701+ def detect_provider (obj : Any , client : "Portkey" , model_name : str ) -> str :
702+ """Detect provider from a response or chunk object.
703+
704+ Parameters
705+ ----------
706+ obj : Any
707+ The response or chunk object to extract provider information from.
708+ client : Portkey
709+ The Portkey client instance.
710+ model_name : str
711+ The model name to use as a fallback for provider detection.
712+
713+ Returns
714+ -------
715+ str
716+ The detected provider name.
717+ """
718+ # First: check Portkey headers on the client (authoritative)
702719 provider = _provider_from_portkey_headers (client )
703720 if provider :
704721 return provider
705- # Next: see if chunk exposes any metadata
706- try :
707- if hasattr (chunk , "response_metadata" ) and _is_dict_like (chunk .response_metadata ):
708- if "provider" in chunk .response_metadata :
709- return chunk .response_metadata ["provider" ]
710- if hasattr (chunk , "_response_headers" ):
711- headers = getattr (chunk , "_response_headers" )
712- if _is_dict_like (headers ):
713- for k , v in headers .items ():
714- if isinstance (k , str ) and k .lower () == "x-portkey-provider" and v :
715- return str (v )
716- except Exception :
717- pass
722+ # Next: check object metadata if any
723+ provider = _extract_provider_from_object (obj )
724+ if provider :
725+ return provider
718726 # Fallback to model name heuristics
719727 return detect_provider_from_model_name (model_name )
720728
0 commit comments