Skip to content

Commit 3bb5ff2

Browse files
authored
Removing kernel messaging in aqua (#1304)
2 parents ab35ba2 + e43d541 commit 3bb5ff2

File tree

3 files changed

+332
-50
lines changed

3 files changed

+332
-50
lines changed

ads/aqua/common/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class PredictEndpoints(ExtendedEnum):
2424
CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
2525
TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
2626
EMBEDDING_ENDPOINT = "/v1/embedding"
27+
RESPONSES = "/v1/responses"
2728

2829

2930
class Tags(ExtendedEnum):

ads/aqua/extension/deployment_handler.py

Lines changed: 253 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5-
from typing import List, Optional, Union
5+
from typing import List, Union
66
from urllib.parse import urlparse
77

88
from tornado.web import HTTPError
99

10-
from ads.aqua.app import logger
10+
from ads.aqua import logger
1111
from ads.aqua.client.client import Client, ExtendedRequestError
12+
from ads.aqua.client.openai_client import OpenAI
1213
from ads.aqua.common.decorator import handle_exceptions
1314
from ads.aqua.common.enums import PredictEndpoints
1415
from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -221,12 +222,98 @@ def list_shapes(self):
221222

222223

223224
class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
224-
def _get_model_deployment_response(
225-
self,
226-
model_deployment_id: str,
227-
payload: dict,
228-
route_override_header: Optional[str],
229-
):
225+
def _extract_text_from_choice(self, choice: dict) -> str:
226+
"""
227+
Extract text content from a single choice structure.
228+
229+
Handles both dictionary-based API responses and object-based SDK responses.
230+
For dict choices, it checks delta-based streaming fields, message-based
231+
non-streaming fields, and finally top-level text/content keys.
232+
For object choices, it inspects `.delta`, `.message`, and top-level
233+
`.text` or `.content` attributes.
234+
235+
Parameters
236+
----------
237+
choice : dict
238+
A choice entry from a model response. It may be:
239+
- A dict originating from a JSON API response (streaming or non-streaming).
240+
- An SDK-style object with attributes such as `delta`, `message`,
241+
`text`, or `content`.
242+
243+
For dicts, the method checks:
244+
• delta → content/text
245+
• message → content/text
246+
• top-level → text/content
247+
248+
For objects, the method checks the same fields via attributes.
249+
250+
Returns
251+
-------
252+
str | None:
253+
The extracted text if present; otherwise None.
254+
"""
255+
# choice may be a dict or an object
256+
if isinstance(choice, dict):
257+
# streaming chunk: {"delta": {"content": "..."}}
258+
delta = choice.get("delta")
259+
if isinstance(delta, dict):
260+
return delta.get("content") or delta.get("text") or None
261+
# non-streaming: {"message": {"content": "..."}}
262+
msg = choice.get("message")
263+
if isinstance(msg, dict):
264+
return msg.get("content") or msg.get("text")
265+
# fallback top-level fields
266+
return choice.get("text") or choice.get("content")
267+
# object-like choice
268+
delta = getattr(choice, "delta", None)
269+
if delta is not None:
270+
return getattr(delta, "content", None) or getattr(delta, "text", None)
271+
msg = getattr(choice, "message", None)
272+
if msg is not None:
273+
if isinstance(msg, str):
274+
return msg
275+
return getattr(msg, "content", None) or getattr(msg, "text", None)
276+
return getattr(choice, "text", None) or getattr(choice, "content", None)
277+
278+
def _extract_text_from_chunk(self, chunk: dict) -> str:
279+
"""
280+
Extract text content from a model response chunk.
281+
282+
Supports both dict-form chunks (streaming or non-streaming) and SDK-style
283+
object chunks. When choices are present, extraction is delegated to
284+
`_extract_text_from_choice`. If no choices exist, top-level text/content
285+
fields or attributes are used.
286+
287+
Parameters
288+
----------
289+
chunk : dict
290+
A chunk returned from a model stream or full response. It may be:
291+
- A dict containing a `choices` list or top-level text/content fields.
292+
- An SDK-style object with a `choices` attribute or top-level
293+
`text`/`content` attributes.
294+
295+
If `choices` is present, the method extracts text from the first
296+
choice using `_extract_text_from_choice`. Otherwise, it falls back
297+
to top-level text/content.
298+
Returns
299+
-------
300+
str
301+
The extracted text if present; otherwise None.
302+
"""
303+
if chunk:
304+
if isinstance(chunk, dict):
305+
choices = chunk.get("choices") or []
306+
if choices:
307+
return self._extract_text_from_choice(choices[0])
308+
# fallback top-level
309+
return chunk.get("text") or chunk.get("content")
310+
# object-like chunk
311+
choices = getattr(chunk, "choices", None)
312+
if choices:
313+
return self._extract_text_from_choice(choices[0])
314+
return getattr(chunk, "text", None) or getattr(chunk, "content", None)
315+
316+
def _get_model_deployment_response(self, model_deployment_id: str, payload: dict):
    """
    Returns the model deployment inference response in a streaming fashion.

    Validates the payload, routes the request to the appropriate
    OpenAI-compatible endpoint (chat completions — plain or multimodal —
    text completions, or responses) against the deployment's
    ``/predictWithResponseStream/v1`` endpoint, and yields the extracted
    text pieces as they arrive.

    Parameters
    ----------
    model_deployment_id : str
        Identifier of the model deployment to run inference against.
    payload : dict
        Request payload. Must contain ``endpoint_type``, ``prompt`` and
        ``model``. May also carry sampling parameters (``max_tokens``,
        ``temperature``, ``top_p``, ...), an ``encoded_image`` data URL plus
        ``file_type`` for multimodal chat, and a ``chat_template``.

    Yields
    ------
    str
        Successive text fragments of the streamed model response.

    Raises
    ------
    HTTPError
        400 for malformed payloads, unsupported endpoint/file types, or
        client-side request errors; 500 for unexpected failures.
    """
    model_deployment = AquaDeploymentApp().get(model_deployment_id)
    endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"

    required_keys = ["endpoint_type", "prompt", "model"]
    missing = [k for k in required_keys if k not in payload]
    if missing:
        raise HTTPError(400, f"Missing required payload keys: {', '.join(missing)}")

    endpoint_type = payload["endpoint_type"]
    aqua_client = OpenAI(base_url=endpoint)

    # Sampling parameters forwarded to chat/text completion endpoints.
    allowed = {
        "max_tokens",
        "temperature",
        "top_p",
        "stop",
        "n",
        "presence_penalty",
        "frequency_penalty",
        "logprobs",
        "user",
        "echo",
    }
    # The /v1/responses API accepts a narrower parameter set.
    responses_allowed = {"temperature", "top_p"}

    # Normalize: an empty stop list means "no stop sequences".
    if payload.get("stop") == []:
        payload["stop"] = None

    # "NA" is the sentinel for "no attachment" used by the routing below.
    encoded_image = payload.get("encoded_image", "NA")

    model = payload.pop("model")
    filtered = {k: v for k, v in payload.items() if k in allowed}
    responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}

    if (
        endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
        and encoded_image == "NA"
    ):
        # Plain text chat completion.
        try:
            api_kwargs = {
                "model": model,
                "messages": [{"role": "user", "content": payload["prompt"]}],
                "stream": True,
                **filtered,
            }
            if "chat_template" in payload:
                api_kwargs["extra_body"] = {
                    "chat_template": payload.pop("chat_template")
                }
            stream = aqua_client.chat.completions.create(**api_kwargs)
            for chunk in stream:
                piece = self._extract_text_from_chunk(chunk)
                if piece:
                    yield piece
        except ExtendedRequestError as ex:
            raise HTTPError(400, str(ex)) from ex
        except Exception as ex:
            raise HTTPError(500, str(ex)) from ex

    elif (
        endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
        and encoded_image != "NA"
    ):
        # Multimodal chat completion (image or audio attachment). The two
        # media branches differ only in the content part, so build it once.
        file_type = payload.pop("file_type", "")
        if file_type.startswith("image"):
            media_part = {
                "type": "image_url",
                "image_url": {"url": f"{encoded_image}"},
            }
        elif file_type.startswith("audio"):
            media_part = {
                "type": "audio_url",
                "audio_url": {"url": f"{encoded_image}"},
            }
        else:
            # Previously an unrecognized file_type fell through to an unbound
            # `response` variable and surfaced as a 500; reject it explicitly.
            raise HTTPError(400, f"Unsupported file_type: {file_type}")

        api_kwargs = {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": payload["prompt"]},
                        media_part,
                    ],
                }
            ],
            "stream": True,
            **filtered,
        }
        # Add chat_template for multimodal chat completions.
        if "chat_template" in payload:
            api_kwargs["extra_body"] = {"chat_template": payload.pop("chat_template")}

        try:
            response = aqua_client.chat.completions.create(**api_kwargs)
            for chunk in response:
                piece = self._extract_text_from_chunk(chunk)
                if piece:
                    yield piece
        except ExtendedRequestError as ex:
            raise HTTPError(400, str(ex)) from ex
        except Exception as ex:
            raise HTTPError(500, str(ex)) from ex

    elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
        # Legacy text completion endpoint.
        try:
            for chunk in aqua_client.completions.create(
                prompt=payload["prompt"], stream=True, model=model, **filtered
            ):
                piece = self._extract_text_from_chunk(chunk)
                if piece:
                    yield piece
        except ExtendedRequestError as ex:
            raise HTTPError(400, str(ex)) from ex
        except Exception as ex:
            raise HTTPError(500, str(ex)) from ex

    elif endpoint_type == PredictEndpoints.RESPONSES:
        # /v1/responses only accepts the reduced parameter set.
        kwargs = {"model": model, "input": payload["prompt"], "stream": True}
        kwargs.update(responses_filtered)
        try:
            response = aqua_client.responses.create(**kwargs)
            for chunk in response:
                piece = self._extract_text_from_chunk(chunk)
                if piece:
                    yield piece
        except ExtendedRequestError as ex:
            raise HTTPError(400, str(ex)) from ex
        except Exception as ex:
            raise HTTPError(500, str(ex)) from ex

    else:
        raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}")
322528

323529
@handle_exceptions
324530
def post(self, model_deployment_id):
@@ -346,18 +552,17 @@ def post(self, model_deployment_id):
346552
)
347553
if not input_data.get("model"):
348554
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
349-
route_override_header = self.request.headers.get("route", None)
350555
self.set_header("Content-Type", "text/event-stream")
351556
response_gen = self._get_model_deployment_response(
352-
model_deployment_id, input_data, route_override_header
557+
model_deployment_id, input_data
353558
)
354559
try:
355560
for chunk in response_gen:
356561
self.write(chunk)
357562
self.flush()
358563
self.finish()
359564
except Exception as ex:
360-
self.set_status(ex.status_code)
565+
self.set_status(getattr(ex, "status_code", 500))
361566
self.write({"message": "Error occurred", "reason": str(ex)})
362567
self.finish()
363568

0 commit comments

Comments
 (0)