Skip to content

Commit 388289d

Browse files
committed
running precommit hook
1 parent 7c1d125 commit 388289d

File tree

1 file changed

+106
-113
lines changed

1 file changed

+106
-113
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 106 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5-
from typing import List, Optional, Union
5+
from typing import List, Union
66
from urllib.parse import urlparse
77

88
from tornado.web import HTTPError
99

10+
from ads.aqua import logger
1011
from ads.aqua.client.client import Client, ExtendedRequestError
1112
from ads.aqua.client.openai_client import OpenAI
1213
from ads.aqua.common.decorator import handle_exceptions
@@ -15,7 +16,6 @@
1516
from ads.aqua.extension.errors import Errors
1617
from ads.aqua.modeldeployment import AquaDeploymentApp
1718
from ads.config import COMPARTMENT_OCID
18-
from ads.aqua import logger
1919

2020

2121
class AquaDeploymentHandler(AquaAPIhandler):
@@ -222,7 +222,6 @@ def list_shapes(self):
222222

223223

224224
class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
225-
226225
def _extract_text_from_choice(self, choice: dict) -> str:
227226
"""
228227
Extract text content from a single choice structure.
@@ -242,9 +241,9 @@ def _extract_text_from_choice(self, choice: dict) -> str:
242241
`text`, or `content`.
243242
244243
For dicts, the method checks:
245-
• delta → content/text
246-
• message → content/text
247-
• top-level → text/content
244+
• delta → content/text
245+
• message → content/text
246+
• top-level → text/content
248247
249248
For objects, the method checks the same fields via attributes.
250249
@@ -276,7 +275,7 @@ def _extract_text_from_choice(self, choice: dict) -> str:
276275
return getattr(msg, "content", None) or getattr(msg, "text", None)
277276
return getattr(choice, "text", None) or getattr(choice, "content", None)
278277

279-
def _extract_text_from_chunk(self, chunk: dict) -> str :
278+
def _extract_text_from_chunk(self, chunk: dict) -> str:
280279
"""
281280
Extract text content from a model response chunk.
282281
@@ -301,7 +300,7 @@ def _extract_text_from_chunk(self, chunk: dict) -> str :
301300
str
302301
The extracted text if present; otherwise None.
303302
"""
304-
if chunk :
303+
if chunk:
305304
if isinstance(chunk, dict):
306305
choices = chunk.get("choices") or []
307306
if choices:
@@ -314,11 +313,7 @@ def _extract_text_from_chunk(self, chunk: dict) -> str :
314313
return self._extract_text_from_choice(choices[0])
315314
return getattr(chunk, "text", None) or getattr(chunk, "content", None)
316315

317-
def _get_model_deployment_response(
318-
self,
319-
model_deployment_id: str,
320-
payload: dict
321-
):
316+
def _get_model_deployment_response(self, model_deployment_id: str, payload: dict):
322317
"""
323318
Returns the model deployment inference response in a streaming fashion.
324319
@@ -371,7 +366,7 @@ def _get_model_deployment_response(
371366

372367
if missing:
373368
raise HTTPError(400, f"Missing required payload keys: {', '.join(missing)}")
374-
369+
375370
endpoint_type = payload["endpoint_type"]
376371
aqua_client = OpenAI(base_url=endpoint)
377372

@@ -387,148 +382,147 @@ def _get_model_deployment_response(
387382
"user",
388383
"echo",
389384
}
390-
responses_allowed = {
391-
"temperature", "top_p"
392-
}
385+
responses_allowed = {"temperature", "top_p"}
393386

394387
# normalize and filter
395388
if payload.get("stop") == []:
396389
payload["stop"] = None
397390

398391
encoded_image = "NA"
399-
if "encoded_image" in payload :
392+
if "encoded_image" in payload:
400393
encoded_image = payload["encoded_image"]
401394

402395
model = payload.pop("model")
403396
filtered = {k: v for k, v in payload.items() if k in allowed}
404-
responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}
397+
responses_filtered = {
398+
k: v for k, v in payload.items() if k in responses_allowed
399+
}
405400

406-
if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT == endpoint_type and encoded_image == "NA":
401+
if (
402+
endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
403+
and encoded_image == "NA"
404+
):
407405
try:
408406
api_kwargs = {
409407
"model": model,
410408
"messages": [{"role": "user", "content": payload["prompt"]}],
411409
"stream": True,
412-
**filtered
410+
**filtered,
413411
}
414412
if "chat_template" in payload:
415413
chat_template = payload.pop("chat_template")
416414
api_kwargs["extra_body"] = {"chat_template": chat_template}
417-
415+
418416
stream = aqua_client.chat.completions.create(**api_kwargs)
419417

420418
for chunk in stream:
421-
if chunk :
422-
piece = self._extract_text_from_chunk(chunk)
423-
if piece :
424-
yield piece
419+
if chunk:
420+
piece = self._extract_text_from_chunk(chunk)
421+
if piece:
422+
yield piece
425423
except ExtendedRequestError as ex:
426-
raise HTTPError(400, str(ex))
424+
raise HTTPError(400, str(ex)) from ex
427425
except Exception as ex:
428-
raise HTTPError(500, str(ex))
426+
raise HTTPError(500, str(ex)) from ex
429427

430428
elif (
431-
endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
432-
and encoded_image != "NA"
433-
):
434-
file_type = payload.pop("file_type")
435-
if file_type.startswith("image"):
436-
api_kwargs = {
437-
"model": model,
438-
"messages": [
439-
{
440-
"role": "user",
441-
"content": [
442-
{"type": "text", "text": payload["prompt"]},
443-
{
444-
"type": "image_url",
445-
"image_url": {"url": f"{encoded_image}"},
446-
},
447-
],
448-
}
449-
],
450-
"stream": True,
451-
**filtered
452-
}
453-
454-
# Add chat_template for image-based chat completions
455-
if "chat_template" in payload:
456-
chat_template = payload.pop("chat_template")
457-
api_kwargs["extra_body"] = {"chat_template": chat_template}
458-
459-
response = aqua_client.chat.completions.create(**api_kwargs)
460-
461-
elif file_type.startswith("audio"):
462-
api_kwargs = {
463-
"model": model,
464-
"messages": [
465-
{
466-
"role": "user",
467-
"content": [
468-
{"type": "text", "text": payload["prompt"]},
469-
{
470-
"type": "audio_url",
471-
"audio_url": {"url": f"{encoded_image}"},
472-
},
473-
],
474-
}
475-
],
476-
"stream": True,
477-
**filtered
478-
}
479-
480-
# Add chat_template for audio-based chat completions
481-
if "chat_template" in payload:
482-
chat_template = payload.pop("chat_template")
483-
api_kwargs["extra_body"] = {"chat_template": chat_template}
484-
485-
response = aqua_client.chat.completions.create(**api_kwargs)
486-
try:
487-
for chunk in response:
488-
piece = self._extract_text_from_chunk(chunk)
489-
if piece:
490-
yield piece
491-
except ExtendedRequestError as ex:
492-
raise HTTPError(400, str(ex))
493-
except Exception as ex:
494-
raise HTTPError(500, str(ex))
429+
endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
430+
and encoded_image != "NA"
431+
):
432+
file_type = payload.pop("file_type")
433+
if file_type.startswith("image"):
434+
api_kwargs = {
435+
"model": model,
436+
"messages": [
437+
{
438+
"role": "user",
439+
"content": [
440+
{"type": "text", "text": payload["prompt"]},
441+
{
442+
"type": "image_url",
443+
"image_url": {"url": f"{encoded_image}"},
444+
},
445+
],
446+
}
447+
],
448+
"stream": True,
449+
**filtered,
450+
}
451+
452+
# Add chat_template for image-based chat completions
453+
if "chat_template" in payload:
454+
chat_template = payload.pop("chat_template")
455+
api_kwargs["extra_body"] = {"chat_template": chat_template}
456+
457+
response = aqua_client.chat.completions.create(**api_kwargs)
458+
459+
elif file_type.startswith("audio"):
460+
api_kwargs = {
461+
"model": model,
462+
"messages": [
463+
{
464+
"role": "user",
465+
"content": [
466+
{"type": "text", "text": payload["prompt"]},
467+
{
468+
"type": "audio_url",
469+
"audio_url": {"url": f"{encoded_image}"},
470+
},
471+
],
472+
}
473+
],
474+
"stream": True,
475+
**filtered,
476+
}
477+
478+
# Add chat_template for audio-based chat completions
479+
if "chat_template" in payload:
480+
chat_template = payload.pop("chat_template")
481+
api_kwargs["extra_body"] = {"chat_template": chat_template}
482+
483+
response = aqua_client.chat.completions.create(**api_kwargs)
484+
try:
485+
for chunk in response:
486+
piece = self._extract_text_from_chunk(chunk)
487+
if piece:
488+
yield piece
489+
except ExtendedRequestError as ex:
490+
raise HTTPError(400, str(ex)) from ex
491+
except Exception as ex:
492+
raise HTTPError(500, str(ex)) from ex
495493
elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
496494
try:
497495
for chunk in aqua_client.completions.create(
498496
prompt=payload["prompt"], stream=True, model=model, **filtered
499497
):
500-
if chunk :
501-
piece = self._extract_text_from_chunk(chunk)
502-
if piece :
503-
yield piece
498+
if chunk:
499+
piece = self._extract_text_from_chunk(chunk)
500+
if piece:
501+
yield piece
504502
except ExtendedRequestError as ex:
505-
raise HTTPError(400, str(ex))
503+
raise HTTPError(400, str(ex)) from ex
506504
except Exception as ex:
507-
raise HTTPError(500, str(ex))
505+
raise HTTPError(500, str(ex)) from ex
508506

509507
elif endpoint_type == PredictEndpoints.RESPONSES:
510-
api_kwargs = {
511-
"model": model,
512-
"input": payload["prompt"],
513-
"stream": True
514-
}
508+
kwargs = {"model": model, "input": payload["prompt"], "stream": True}
515509

516510
if "temperature" in responses_filtered:
517-
api_kwargs["temperature"] = responses_filtered["temperature"]
511+
kwargs["temperature"] = responses_filtered["temperature"]
518512
if "top_p" in responses_filtered:
519-
api_kwargs["top_p"] = responses_filtered["top_p"]
513+
kwargs["top_p"] = responses_filtered["top_p"]
520514

521-
response = aqua_client.responses.create(**api_kwargs)
515+
response = aqua_client.responses.create(**kwargs)
522516
try:
523517
for chunk in response:
524-
if chunk :
525-
piece = self._extract_text_from_chunk(chunk)
526-
if piece :
527-
yield piece
518+
if chunk:
519+
piece = self._extract_text_from_chunk(chunk)
520+
if piece:
521+
yield piece
528522
except ExtendedRequestError as ex:
529-
raise HTTPError(400, str(ex))
523+
raise HTTPError(400, str(ex)) from ex
530524
except Exception as ex:
531-
raise HTTPError(500, str(ex))
525+
raise HTTPError(500, str(ex)) from ex
532526
else:
533527
raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}")
534528

@@ -552,7 +546,6 @@ def post(self, model_deployment_id):
552546
prompt = input_data.get("prompt")
553547
messages = input_data.get("messages")
554548

555-
556549
if not prompt and not messages:
557550
raise HTTPError(
558551
400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")

0 commit comments

Comments (0)