|
2 | 2 | # Copyright (c) 2024, 2025 Oracle and/or its affiliates. |
3 | 3 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
4 | 4 |
|
5 | | -from typing import List, Optional, Union |
| 5 | +from typing import List, Union |
6 | 6 | from urllib.parse import urlparse |
7 | 7 |
|
8 | 8 | from tornado.web import HTTPError |
9 | 9 |
|
10 | | -from ads.aqua.app import logger |
| 10 | +from ads.aqua import logger |
11 | 11 | from ads.aqua.client.client import Client, ExtendedRequestError |
| 12 | +from ads.aqua.client.openai_client import OpenAI |
12 | 13 | from ads.aqua.common.decorator import handle_exceptions |
13 | 14 | from ads.aqua.common.enums import PredictEndpoints |
14 | 15 | from ads.aqua.extension.base_handler import AquaAPIhandler |
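
For context on the new `OpenAI` import: the handler below drives it exactly like the standard `openai` SDK (`chat.completions.create`, `completions.create`, `responses.create`) and passes the deployment's streaming route as `base_url`. A minimal sketch of that usage, assuming the wrapper handles OCI auth internally; the host and model name are placeholders:

```python
from ads.aqua.client.openai_client import OpenAI

# Point the client at the deployment's streaming route, as the handler does.
# The host and model name are illustrative placeholders.
client = OpenAI(base_url="https://<md-host>/predictWithResponseStream/v1")

stream = client.chat.completions.create(
    model="odsc-llm",  # hypothetical deployed model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    print(chunk)
```
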
@@ -221,12 +222,98 @@ def list_shapes(self): |
221 | 222 |
|
222 | 223 |
|
223 | 224 | class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler): |
224 | | - def _get_model_deployment_response( |
225 | | - self, |
226 | | - model_deployment_id: str, |
227 | | - payload: dict, |
228 | | - route_override_header: Optional[str], |
229 | | - ): |
| 225 | + def _extract_text_from_choice(self, choice) -> Union[str, None]: |
| 226 | + """ |
| 227 | + Extract text content from a single choice structure. |
| 228 | +
|
| 229 | + Handles both dictionary-based API responses and object-based SDK responses. |
| 230 | + For dict choices, it checks delta-based streaming fields, message-based |
| 231 | + non-streaming fields, and finally top-level text/content keys. |
| 232 | + For object choices, it inspects `.delta`, `.message`, and top-level |
| 233 | + `.text` or `.content` attributes. |
| 234 | +
|
| 235 | + Parameters |
| 236 | + ---------- |
| 237 | + choice : dict or object |
| 238 | + A choice entry from a model response. It may be: |
| 239 | + - A dict originating from a JSON API response (streaming or non-streaming). |
| 240 | + - An SDK-style object with attributes such as `delta`, `message`, |
| 241 | + `text`, or `content`. |
| 242 | +
|
| 243 | + For dicts, the method checks: |
| 244 | + • delta → content/text |
| 245 | + • message → content/text |
| 246 | + • top-level → text/content |
| 247 | +
|
| 248 | + For objects, the method checks the same fields via attributes. |
| 249 | +
|
| 250 | + Returns |
| 251 | + ------- |
| 252 | + str or None |
| 253 | + The extracted text if present; otherwise None. |
| 254 | + """ |
| 255 | + # choice may be a dict or an object |
| 256 | + if isinstance(choice, dict): |
| 257 | + # streaming chunk: {"delta": {"content": "..."}} |
| 258 | + delta = choice.get("delta") |
| 259 | + if isinstance(delta, dict): |
| 260 | + return delta.get("content") or delta.get("text") |
| 261 | + # non-streaming: {"message": {"content": "..."}} |
| 262 | + msg = choice.get("message") |
| 263 | + if isinstance(msg, dict): |
| 264 | + return msg.get("content") or msg.get("text") |
| 265 | + # fallback top-level fields |
| 266 | + return choice.get("text") or choice.get("content") |
| 267 | + # object-like choice |
| 268 | + delta = getattr(choice, "delta", None) |
| 269 | + if delta is not None: |
| 270 | + return getattr(delta, "content", None) or getattr(delta, "text", None) |
| 271 | + msg = getattr(choice, "message", None) |
| 272 | + if msg is not None: |
| 273 | + if isinstance(msg, str): |
| 274 | + return msg |
| 275 | + return getattr(msg, "content", None) or getattr(msg, "text", None) |
| 276 | + return getattr(choice, "text", None) or getattr(choice, "content", None) |
| 277 | + |
| 278 | + def _extract_text_from_chunk(self, chunk) -> Union[str, None]: |
| 279 | + """ |
| 280 | + Extract text content from a model response chunk. |
| 281 | +
|
| 282 | + Supports both dict-form chunks (streaming or non-streaming) and SDK-style |
| 283 | + object chunks. When choices are present, extraction is delegated to |
| 284 | + `_extract_text_from_choice`. If no choices exist, top-level text/content |
| 285 | + fields or attributes are used. |
| 286 | +
|
| 287 | + Parameters |
| 288 | + ---------- |
| 289 | + chunk : dict or object |
| 290 | + A chunk returned from a model stream or full response. It may be: |
| 291 | + - A dict containing a `choices` list or top-level text/content fields. |
| 292 | + - An SDK-style object with a `choices` attribute or top-level |
| 293 | + `text`/`content` attributes. |
| 294 | +
|
| 295 | + If `choices` is present, the method extracts text from the first |
| 296 | + choice using `_extract_text_from_choice`; otherwise it falls back to top-level text/content. |
| 297 | + |
| 298 | + Returns |
| 299 | + ------- |
| 300 | + str or None |
| 301 | + The extracted text if present; otherwise None. |
| 302 | + """ |
| 303 | + if chunk: |
| 304 | + if isinstance(chunk, dict): |
| 305 | + choices = chunk.get("choices") or [] |
| 306 | + if choices: |
| 307 | + return self._extract_text_from_choice(choices[0]) |
| 308 | + # fallback top-level |
| 309 | + return chunk.get("text") or chunk.get("content") |
| 310 | + # object-like chunk |
| 311 | + choices = getattr(chunk, "choices", None) |
| 312 | + if choices: |
| 313 | + return self._extract_text_from_choice(choices[0]) |
| 314 | + return getattr(chunk, "text", None) or getattr(chunk, "content", None) |
| 315 | + |
| 316 | + def _get_model_deployment_response(self, model_deployment_id: str, payload: dict): |
230 | 317 | """ |
231 | 318 | Returns the model deployment inference response in a streaming fashion. |
232 | 319 |
|
@@ -272,53 +359,172 @@ def _get_model_deployment_response( |
272 | 359 | """ |
273 | 360 |
|
274 | 361 | model_deployment = AquaDeploymentApp().get(model_deployment_id) |
275 | | - endpoint = model_deployment.endpoint + "/predictWithResponseStream" |
276 | | - endpoint_type = model_deployment.environment_variables.get( |
277 | | - "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT |
278 | | - ) |
279 | | - aqua_client = Client(endpoint=endpoint) |
280 | | - |
281 | | - if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in ( |
282 | | - endpoint_type, |
283 | | - route_override_header, |
| 362 | + endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1" |
| 363 | + |
| 364 | + required_keys = ["endpoint_type", "prompt", "model"] |
| 365 | + missing = [k for k in required_keys if k not in payload] |
| 366 | + |
| 367 | + if missing: |
| 368 | + raise HTTPError(400, f"Missing required payload keys: {', '.join(missing)}") |
| 369 | + |
| 370 | + endpoint_type = payload["endpoint_type"] |
| 371 | + aqua_client = OpenAI(base_url=endpoint) |
| 372 | + |
| 373 | + allowed = { |
| 374 | + "max_tokens", |
| 375 | + "temperature", |
| 376 | + "top_p", |
| 377 | + "stop", |
| 378 | + "n", |
| 379 | + "presence_penalty", |
| 380 | + "frequency_penalty", |
| 381 | + "logprobs", |
| 382 | + "user", |
| 383 | + "echo", |
| 384 | + } |
| 385 | + responses_allowed = {"temperature", "top_p"} |
| 386 | + |
| 387 | + # normalize and filter |
| 388 | + if payload.get("stop") == []: |
| 389 | + payload["stop"] = None |
| 390 | + |
| 391 | + # "NA" is a sentinel meaning the request carries no image/audio |
| 392 | + # attachment; it routes chat requests to the text-only branch below. |
| 393 | + encoded_image = payload.get("encoded_image", "NA") |
| 394 | + |
| 395 | + model = payload.pop("model") |
| 396 | + filtered = {k: v for k, v in payload.items() if k in allowed} |
| 397 | + responses_filtered = { |
| 398 | + k: v for k, v in payload.items() if k in responses_allowed |
| 399 | + } |
| 400 | + |
| 401 | + if ( |
| 402 | + endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT |
| 403 | + and encoded_image == "NA" |
284 | 404 | ): |
285 | 405 | try: |
286 | | - for chunk in aqua_client.chat( |
287 | | - messages=payload.pop("messages"), |
288 | | - payload=payload, |
289 | | - stream=True, |
290 | | - ): |
291 | | - try: |
292 | | - if "text" in chunk["choices"][0]: |
293 | | - yield chunk["choices"][0]["text"] |
294 | | - elif "content" in chunk["choices"][0]["delta"]: |
295 | | - yield chunk["choices"][0]["delta"]["content"] |
296 | | - except Exception as e: |
297 | | - logger.debug( |
298 | | - f"Exception occurred while parsing streaming response: {e}" |
299 | | - ) |
| 406 | + api_kwargs = { |
| 407 | + "model": model, |
| 408 | + "messages": [{"role": "user", "content": payload["prompt"]}], |
| 409 | + "stream": True, |
| 410 | + **filtered, |
| 411 | + } |
| 412 | + if "chat_template" in payload: |
| 413 | + chat_template = payload.pop("chat_template") |
| 414 | + api_kwargs["extra_body"] = {"chat_template": chat_template} |
| 415 | + |
| 416 | + stream = aqua_client.chat.completions.create(**api_kwargs) |
| 417 | + |
| 418 | + for chunk in stream: |
| 419 | + if chunk: |
| 420 | + piece = self._extract_text_from_chunk(chunk) |
| 421 | + if piece: |
| 422 | + yield piece |
300 | 423 | except ExtendedRequestError as ex: |
301 | | - raise HTTPError(400, str(ex)) |
| 424 | + raise HTTPError(400, str(ex)) from ex |
302 | 425 | except Exception as ex: |
303 | | - raise HTTPError(500, str(ex)) |
| 426 | + raise HTTPError(500, str(ex)) from ex |
| 427 | + |
| 428 | + elif ( |
| 429 | + endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT |
| 430 | + and encoded_image != "NA" |
| 431 | + ): |
| 432 | + file_type = payload.pop("file_type") |
| 433 | + if file_type.startswith("image"): |
| 434 | + api_kwargs = { |
| 435 | + "model": model, |
| 436 | + "messages": [ |
| 437 | + { |
| 438 | + "role": "user", |
| 439 | + "content": [ |
| 440 | + {"type": "text", "text": payload["prompt"]}, |
| 441 | + { |
| 442 | + "type": "image_url", |
| 443 | + "image_url": {"url": f"{encoded_image}"}, |
| 444 | + }, |
| 445 | + ], |
| 446 | + } |
| 447 | + ], |
| 448 | + "stream": True, |
| 449 | + **filtered, |
| 450 | + } |
| 451 | + |
| 452 | + # Add chat_template for image-based chat completions |
| 453 | + if "chat_template" in payload: |
| 454 | + chat_template = payload.pop("chat_template") |
| 455 | + api_kwargs["extra_body"] = {"chat_template": chat_template} |
| 456 | + |
| 457 | + response = aqua_client.chat.completions.create(**api_kwargs) |
| 458 | + |
| 459 | + elif file_type.startswith("audio"): |
| 460 | + api_kwargs = { |
| 461 | + "model": model, |
| 462 | + "messages": [ |
| 463 | + { |
| 464 | + "role": "user", |
| 465 | + "content": [ |
| 466 | + {"type": "text", "text": payload["prompt"]}, |
| 467 | + { |
| 468 | + "type": "audio_url", |
| 469 | + "audio_url": {"url": f"{encoded_image}"}, |
| 470 | + }, |
| 471 | + ], |
| 472 | + } |
| 473 | + ], |
| 474 | + "stream": True, |
| 475 | + **filtered, |
| 476 | + } |
| 477 | + |
| 478 | + # Add chat_template for audio-based chat completions |
| 479 | + if "chat_template" in payload: |
| 480 | + chat_template = payload.pop("chat_template") |
| 481 | + api_kwargs["extra_body"] = {"chat_template": chat_template} |
304 | 482 |
|
| 483 | + response = aqua_client.chat.completions.create(**api_kwargs) |
| 484 | + try: |
| 485 | + for chunk in response: |
| 486 | + piece = self._extract_text_from_chunk(chunk) |
| 487 | + if piece: |
| 488 | + yield piece |
| 489 | + except ExtendedRequestError as ex: |
| 490 | + raise HTTPError(400, str(ex)) from ex |
| 491 | + except Exception as ex: |
| 492 | + raise HTTPError(500, str(ex)) from ex |
305 | 493 | elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT: |
306 | 494 | try: |
307 | | - for chunk in aqua_client.generate( |
308 | | - prompt=payload.pop("prompt"), |
309 | | - payload=payload, |
310 | | - stream=True, |
| 495 | + for chunk in aqua_client.completions.create( |
| 496 | + prompt=payload["prompt"], stream=True, model=model, **filtered |
311 | 497 | ): |
312 | | - try: |
313 | | - yield chunk["choices"][0]["text"] |
314 | | - except Exception as e: |
315 | | - logger.debug( |
316 | | - f"Exception occurred while parsing streaming response: {e}" |
317 | | - ) |
| 498 | + if chunk: |
| 499 | + piece = self._extract_text_from_chunk(chunk) |
| 500 | + if piece: |
| 501 | + yield piece |
| 502 | + except ExtendedRequestError as ex: |
| 503 | + raise HTTPError(400, str(ex)) from ex |
| 504 | + except Exception as ex: |
| 505 | + raise HTTPError(500, str(ex)) from ex |
| 506 | + |
| 507 | + elif endpoint_type == PredictEndpoints.RESPONSES: |
| 508 | + kwargs = {"model": model, "input": payload["prompt"], "stream": True} |
| 509 | + |
| 510 | + if "temperature" in responses_filtered: |
| 511 | + kwargs["temperature"] = responses_filtered["temperature"] |
| 512 | + if "top_p" in responses_filtered: |
| 513 | + kwargs["top_p"] = responses_filtered["top_p"] |
| 514 | + |
| 515 | + response = aqua_client.responses.create(**kwargs) |
| 516 | + try: |
| 517 | + for chunk in response: |
| 518 | + if chunk: |
| 519 | + piece = self._extract_text_from_chunk(chunk) |
| 520 | + if piece: |
| 521 | + yield piece |
318 | 522 | except ExtendedRequestError as ex: |
319 | | - raise HTTPError(400, str(ex)) |
| 523 | + raise HTTPError(400, str(ex)) from ex |
320 | 524 | except Exception as ex: |
321 | | - raise HTTPError(500, str(ex)) |
| 525 | + raise HTTPError(500, str(ex)) from ex |
| 526 | + else: |
| 527 | + raise HTTPError(400, f"Unsupported endpoint_type: {endpoint_type}") |
322 | 528 |
|
323 | 529 | @handle_exceptions |
324 | 530 | def post(self, model_deployment_id): |
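
For reference, a request body this generator accepts, with illustrative values. `endpoint_type`, `prompt`, and `model` are the keys it validates up front; the sampling keys shown come from its `allowed` set, and the commented-out fields are the optional extras it recognizes:

```python
from ads.aqua.common.enums import PredictEndpoints

payload = {
    # Required keys, validated by _get_model_deployment_response.
    # Over HTTP, endpoint_type is sent as the enum's string value.
    "endpoint_type": PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT,
    "model": "odsc-llm",  # hypothetical deployed model name
    "prompt": "Summarize the attached report.",
    # Optional sampling parameters; only keys in the `allowed` set are
    # forwarded, and a `stop` of [] is normalized to None.
    "max_tokens": 256,
    "temperature": 0.2,
    "top_p": 0.9,
    "stop": [],
    # Optional multimodal attachment; `file_type` selects the image or
    # audio branch when `encoded_image` is present.
    # "encoded_image": "data:image/png;base64,<...>",
    # "file_type": "image/png",
    # Optional chat template, forwarded to the server via `extra_body`.
    # "chat_template": "<template string>",
}
```
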
@@ -346,18 +552,17 @@ def post(self, model_deployment_id): |
346 | 552 | ) |
347 | 553 | if not input_data.get("model"): |
348 | 554 | raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model")) |
349 | | - route_override_header = self.request.headers.get("route", None) |
350 | 555 | self.set_header("Content-Type", "text/event-stream") |
351 | 556 | response_gen = self._get_model_deployment_response( |
352 | | - model_deployment_id, input_data, route_override_header |
| 557 | + model_deployment_id, input_data |
353 | 558 | ) |
354 | 559 | try: |
355 | 560 | for chunk in response_gen: |
356 | 561 | self.write(chunk) |
357 | 562 | self.flush() |
358 | 563 | self.finish() |
359 | 564 | except Exception as ex: |
360 | | - self.set_status(ex.status_code) |
| 565 | + self.set_status(getattr(ex, "status_code", 500)) |
361 | 566 | self.write({"message": "Error occurred", "reason": str(ex)}) |
362 | 567 | self.finish() |
363 | 568 |
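
Finally, a client-side sketch for consuming the streaming endpoint, reusing the `payload` shape above. The URL is hypothetical; the handler advertises `text/event-stream` but writes raw text pieces without SSE framing, so iterating the response body directly is sufficient:

```python
import requests

# Hypothetical extension route; substitute the real handler URL and the
# model deployment OCID.
url = "http://localhost:8888/aqua/deployments/<md-ocid>/inference/stream"

with requests.post(url, json=payload, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    # Pieces arrive as soon as the handler flushes them.
    for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
        if piece:
            print(piece, end="", flush=True)
```
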
|
|