@@ -84,6 +84,7 @@ def __init__(
8484 follow_redirects : Optional [bool ] = None ,
8585 max_output_tokens : Optional [int ] = None ,
8686 extra_query : Optional [dict ] = None ,
87+ extra_body : Optional [dict ] = None ,
8788 ):
8889 super ().__init__ (type_ = "openai_http" )
8990 self ._target = target or settings .openai .base_url
@@ -120,6 +121,7 @@ def __init__(
120121 else settings .openai .max_output_tokens
121122 )
122123 self .extra_query = extra_query
124+ self .extra_body = extra_body
123125 self ._async_client : Optional [httpx .AsyncClient ] = None
124126
125127 @property
@@ -242,7 +244,9 @@ async def text_completions( # type: ignore[override]
242244
243245 headers = self ._headers ()
244246 params = self ._params (TEXT_COMPLETIONS )
247+ body = self ._body (TEXT_COMPLETIONS )
245248 payload = self ._completions_payload (
249+ body = body ,
246250 orig_kwargs = kwargs ,
247251 max_output_tokens = output_token_count ,
248252 prompt = prompt ,
@@ -317,10 +321,12 @@ async def chat_completions( # type: ignore[override]
317321 logger .debug ("{} invocation with args: {}" , self .__class__ .__name__ , locals ())
318322 headers = self ._headers ()
319323 params = self ._params (CHAT_COMPLETIONS )
324+ body = self ._body (CHAT_COMPLETIONS )
320325 messages = (
321326 content if raw_content else self ._create_chat_messages (content = content )
322327 )
323328 payload = self ._completions_payload (
329+ body = body ,
324330 orig_kwargs = kwargs ,
325331 max_output_tokens = output_token_count ,
326332 messages = messages ,
@@ -396,10 +402,28 @@ def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
396402
397403 return self .extra_query
398404
405+ def _body (self , endpoint_type : EndpointType ) -> dict [str , str ]:
406+ if self .extra_body is None :
407+ return {}
408+
409+ if (
410+ CHAT_COMPLETIONS in self .extra_body
411+ or MODELS in self .extra_body
412+ or TEXT_COMPLETIONS in self .extra_body
413+ ):
414+ return self .extra_body .get (endpoint_type , {})
415+
416+ return self .extra_body
417+
399418 def _completions_payload (
400- self , orig_kwargs : Optional [dict ], max_output_tokens : Optional [int ], ** kwargs
419+ self ,
420+ body : Optional [dict ],
421+ orig_kwargs : Optional [dict ],
422+ max_output_tokens : Optional [int ],
423+ ** kwargs ,
401424 ) -> dict :
402- payload = orig_kwargs or {}
425+ payload = body or {}
426+ payload .update (orig_kwargs or {})
403427 payload .update (kwargs )
404428 payload ["model" ] = self .model
405429 payload ["stream" ] = True
0 commit comments