From 97507f2619a059e195d541535a2e3a23d0f55e6c Mon Sep 17 00:00:00 2001 From: declark1 <44146800+declark1@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:08:03 -0800 Subject: [PATCH] wip Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com> --- .../chat_completions_detection/streaming.rs | 52 +++++++++++++++---- .../completions_detection/streaming.rs | 52 +++++++++++++++---- .../streaming_classification_with_gen.rs | 5 +- .../handlers/streaming_content_detection.rs | 5 +- .../types/detection_batch_stream.rs | 46 +++++++++++----- .../types/detection_batcher/completion.rs | 6 ++- .../detection_batcher/max_processed_index.rs | 6 ++- 7 files changed, 136 insertions(+), 36 deletions(-) diff --git a/src/orchestrator/handlers/chat_completions_detection/streaming.rs b/src/orchestrator/handlers/chat_completions_detection/streaming.rs index 60c5ddd1..b03aeff0 100644 --- a/src/orchestrator/handlers/chat_completions_detection/streaming.rs +++ b/src/orchestrator/handlers/chat_completions_detection/streaming.rs @@ -18,7 +18,7 @@ use std::{collections::HashMap, sync::Arc}; use futures::{StreamExt, future::try_join_all}; use opentelemetry::trace::TraceId; -use tokio::sync::mpsc; +use tokio::sync::{broadcast, mpsc}; use tracing::{Instrument, debug, error, info, instrument, warn}; use uuid::Uuid; @@ -121,7 +121,7 @@ pub async fn handle_streaming( if output_detectors.is_empty() { // No output detectors, forward chat completion chunks to response channel - process_chat_completion_stream(trace_id, chat_completion_stream, None, None, Some(response_tx.clone())).await; + process_chat_completion_stream(trace_id, chat_completion_stream, None, None, response_tx.clone(), true, None).await; info!(%trace_id, "task completed: chat completion stream closed"); } else { // Handle output detection @@ -255,6 +255,8 @@ async fn handle_output_detection( if !chunk_detectors.is_empty() { // Set up streaming detection pipeline + // Create channel to shutdown detection pipeline + let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1); // n represents how many choices to generate for each input message // Choices are processed independently so each choice has its own input channels and detection streams. let n = request.extra.get("n").and_then(|v| v.as_i64()).unwrap_or(1) as usize; @@ -295,12 +297,15 @@ async fn handle_output_detection( chat_completion_stream, Some(completion_state.clone()), Some(input_txs), - None, + response_tx.clone(), + false, + Some(shutdown_tx.clone()), )); // Process detection streams and await completion let detection_batch_stream = DetectionBatchStream::new( CompletionBatcher::new(chunk_detectors.len()), detection_streams, + shutdown_rx, ); process_detection_batch_stream( trace_id, @@ -317,7 +322,9 @@ async fn handle_output_detection( chat_completion_stream, Some(completion_state.clone()), None, - Some(response_tx.clone()), + response_tx.clone(), + true, + None, ) .await; } @@ -368,15 +375,18 @@ async fn process_chat_completion_stream( mut chat_completion_stream: ChatCompletionStream, completion_state: Option>>, input_txs: Option>>>, - response_tx: Option, Error>>>, + response_tx: mpsc::Sender, Error>>, + passthrough: bool, + shutdown_tx: Option>, ) { + let mut no_generated_text = false; while let Some((message_index, result)) = chat_completion_stream.next().await { match result { Ok(Some(chat_completion)) => { // Send chat completion chunk to response channel // NOTE: this forwards chat completion chunks without detections and is only // done here for 2 cases: a) no output detectors b) only whole doc output detectors - if let Some(response_tx) = &response_tx + if passthrough && response_tx .send(Ok(Some(chat_completion.clone()))) .await @@ -385,6 +395,32 @@ async fn process_chat_completion_stream( info!(%trace_id, "task completed: client disconnected"); return; } + + // First message contains finish_reason, no text was generated + if message_index == 0 + && chat_completion + .choices + .first() + .and_then(|choice| choice.finish_reason.as_ref()) + .is_some() + { + warn!(%trace_id, ?chat_completion, "first message contains finish_reason, no text was generated"); + no_generated_text = true; + // Send shutdown signal to detection batch stream + if let Some(shutdown_tx) = &shutdown_tx { + let _ = shutdown_tx.send(()); + } + // Send stop message + let _ = response_tx.send(Ok(Some(chat_completion))).await; + // NOTE: we can't terminate here as the next (final) message contains usage and also needs to be sent + continue; + } + if no_generated_text && chat_completion.usage.is_some() { + // Send usage message and terminate task + let _ = response_tx.send(Ok(Some(chat_completion))).await; + return; + } + if let Some(usage) = &chat_completion.usage && chat_completion.choices.is_empty() { @@ -434,9 +470,7 @@ async fn process_chat_completion_stream( Err(error) => { error!(%trace_id, %error, "task failed: error received from chat completion stream"); // Send error to response channel - if let Some(response_tx) = &response_tx { - let _ = response_tx.send(Err(error.clone())).await; - } + let _ = response_tx.send(Err(error.clone())).await; // Send error to detection input channels if let Some(input_txs) = &input_txs { for input_tx in input_txs.values() { diff --git a/src/orchestrator/handlers/completions_detection/streaming.rs b/src/orchestrator/handlers/completions_detection/streaming.rs index 89d103fd..6f048bb9 100644 --- a/src/orchestrator/handlers/completions_detection/streaming.rs +++ b/src/orchestrator/handlers/completions_detection/streaming.rs @@ -18,7 +18,7 @@ use std::{collections::HashMap, sync::Arc}; use futures::{StreamExt, future::try_join_all}; use opentelemetry::trace::TraceId; -use tokio::sync::mpsc; +use tokio::sync::{broadcast, mpsc}; use tracing::{Instrument, debug, error, info, instrument, warn}; use uuid::Uuid; @@ -120,7 +120,7 @@ pub async fn handle_streaming( if output_detectors.is_empty() { // No output detectors, forward completion chunks to response channel - process_completion_stream(trace_id, completion_stream, None, None, Some(response_tx.clone())).await; + process_completion_stream(trace_id, completion_stream, None, None, response_tx.clone(), true, None).await; info!(%trace_id, "task completed: completion stream closed"); } else { // Handle output detection @@ -239,6 +239,8 @@ async fn handle_output_detection( if !chunk_detectors.is_empty() { // Set up streaming detection pipeline + // Create channel to shutdown detection pipeline + let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1); // n represents how many choices to generate for each input message // Choices are processed independently so each choice has its own input channels and detection streams. let n = request.extra.get("n").and_then(|v| v.as_i64()).unwrap_or(1) as usize; @@ -279,12 +281,15 @@ async fn handle_output_detection( completion_stream, Some(completion_state.clone()), Some(input_txs), - None, + response_tx.clone(), + false, + Some(shutdown_tx.clone()), )); // Process detection streams and await completion let detection_batch_stream = DetectionBatchStream::new( CompletionBatcher::new(chunk_detectors.len()), detection_streams, + shutdown_rx, ); process_detection_batch_stream( trace_id, @@ -301,7 +306,9 @@ async fn handle_output_detection( completion_stream, Some(completion_state.clone()), None, - Some(response_tx.clone()), + response_tx.clone(), + true, + None, ) .await; } @@ -350,15 +357,18 @@ async fn process_completion_stream( mut completion_stream: CompletionStream, completion_state: Option>>, input_txs: Option>>>, - response_tx: Option, Error>>>, + response_tx: mpsc::Sender, Error>>, + passthrough: bool, + shutdown_tx: Option>, ) { + let mut no_generated_text = false; while let Some((message_index, result)) = completion_stream.next().await { match result { Ok(Some(completion)) => { // Send completion chunk to response channel // NOTE: this forwards completion chunks without detections and is only // done here for 2 cases: a) no output detectors b) only whole doc output detectors - if let Some(response_tx) = &response_tx + if passthrough && response_tx .send(Ok(Some(completion.clone()))) .await @@ -367,6 +377,32 @@ async fn process_completion_stream( info!(%trace_id, "task completed: client disconnected"); return; } + + // First message contains finish_reason, no text was generated + if message_index == 0 + && completion + .choices + .first() + .and_then(|choice| choice.finish_reason.as_ref()) + .is_some() + { + warn!(%trace_id, ?completion, "first message contains finish_reason, no text was generated"); + no_generated_text = true; + // Send shutdown signal to detection batch stream + if let Some(shutdown_tx) = &shutdown_tx { + let _ = shutdown_tx.send(()); + } + // Send stop message + let _ = response_tx.send(Ok(Some(completion))).await; + // NOTE: we can't terminate here as the next (final) message contains usage and also needs to be sent + continue; + } + if no_generated_text && completion.usage.is_some() { + // Send usage message and terminate task + let _ = response_tx.send(Ok(Some(completion))).await; + return; + } + if let Some(usage) = &completion.usage && completion.choices.is_empty() { @@ -416,9 +452,7 @@ async fn process_completion_stream( Err(error) => { error!(%trace_id, %error, "task failed: error received from completion stream"); // Send error to response channel - if let Some(response_tx) = &response_tx { - let _ = response_tx.send(Err(error.clone())).await; - } + let _ = response_tx.send(Err(error.clone())).await; // Send error to detection input channels if let Some(input_txs) = &input_txs { for input_tx in input_txs.values() { diff --git a/src/orchestrator/handlers/streaming_classification_with_gen.rs b/src/orchestrator/handlers/streaming_classification_with_gen.rs index 7067f13b..41d59a52 100644 --- a/src/orchestrator/handlers/streaming_classification_with_gen.rs +++ b/src/orchestrator/handlers/streaming_classification_with_gen.rs @@ -23,7 +23,7 @@ use std::{ use futures::StreamExt; use http::HeaderMap; use opentelemetry::trace::TraceId; -use tokio::sync::mpsc; +use tokio::sync::{broadcast, mpsc}; use tokio_stream::wrappers::ReceiverStream; use tracing::{Instrument, error, info, instrument}; @@ -231,6 +231,8 @@ async fn handle_output_detection( let trace_id = task.trace_id; // Create input channel for detection pipeline let (input_tx, input_rx) = mpsc::channel(128); + // Create channel to shutdown detection pipeline + let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1); // Create shared generations let generations: Arc>> = Arc::new(RwLock::new(Vec::new())); @@ -254,6 +256,7 @@ async fn handle_output_detection( let detection_batch_stream = DetectionBatchStream::new( MaxProcessedIndexBatcher::new(detectors.len()), detection_streams, + shutdown_rx, ); process_detection_batch_stream( trace_id, diff --git a/src/orchestrator/handlers/streaming_content_detection.rs b/src/orchestrator/handlers/streaming_content_detection.rs index 6875d936..fbdc5eab 100644 --- a/src/orchestrator/handlers/streaming_content_detection.rs +++ b/src/orchestrator/handlers/streaming_content_detection.rs @@ -19,7 +19,7 @@ use std::{collections::HashMap, pin::Pin, sync::Arc}; use futures::{Stream, StreamExt, stream::Peekable}; use http::HeaderMap; use opentelemetry::trace::TraceId; -use tokio::sync::mpsc; +use tokio::sync::{broadcast, mpsc}; use tokio_stream::wrappers::ReceiverStream; use tracing::{Instrument, error, info, instrument}; @@ -125,6 +125,8 @@ async fn handle_detection( ) { // Create input channel for detection pipeline let (input_tx, input_rx) = mpsc::channel(128); + // Create channel to shutdown detection pipeline + let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1); // Create detection streams let detection_streams = common::text_contents_detection_streams(ctx, headers, detectors.clone(), 0, input_rx).await; @@ -138,6 +140,7 @@ async fn handle_detection( let detection_batch_stream = DetectionBatchStream::new( MaxProcessedIndexBatcher::new(detectors.len()), detection_streams, + shutdown_rx, ); process_detection_batch_stream(trace_id, detection_batch_stream, response_tx) .await; diff --git a/src/orchestrator/types/detection_batch_stream.rs b/src/orchestrator/types/detection_batch_stream.rs index 7d7e9bc3..5a24fd5b 100644 --- a/src/orchestrator/types/detection_batch_stream.rs +++ b/src/orchestrator/types/detection_batch_stream.rs @@ -15,8 +15,8 @@ */ use futures::{Stream, StreamExt, stream}; -use tokio::sync::{mpsc, oneshot}; -use tracing::{debug, error}; +use tokio::sync::{broadcast, mpsc, oneshot}; +use tracing::{debug, error, warn}; use super::{Batch, Chunk, Detection, DetectionBatcher, DetectionStream}; use crate::orchestrator::Error; @@ -31,23 +31,40 @@ pub struct DetectionBatchStream { } impl DetectionBatchStream { - pub fn new(batcher: impl DetectionBatcher, mut streams: Vec) -> Self { + pub fn new( + batcher: impl DetectionBatcher, + mut streams: Vec, + mut shutdown_rx: broadcast::Receiver<()>, + ) -> Self { let (batch_tx, batch_rx) = mpsc::channel(32); // Spawn task to receive detections and process batches tokio::spawn(async move { if streams.len() == 1 { // Skip the batching process for a single detection stream let mut stream = streams.swap_remove(0); - while let Some(msg) = stream.next().await { - match msg { - Ok(batch) => { - debug!(?batch, "sending batch to batch channel"); - let _ = batch_tx.send(Ok(batch)).await; - } - Err(error) => { - error!(?error, "sending error to batch channel"); - let _ = batch_tx.send(Err(error)).await; + loop { + tokio::select! { + // Receive shutdown signal + _ = shutdown_rx.recv() => { + warn!("received shutdown signal, terminating task"); break; + }, + // Receive detections and send to batch channel + msg = stream.next() => { + match msg { + Some(Ok(batch)) => { + debug!(?batch, "sending batch to batch channel"); + let _ = batch_tx.send(Ok(batch)).await; + } + Some(Err(error)) => { + error!(?error, "sending error to batch channel"); + let _ = batch_tx.send(Err(error)).await; + break; + } + None => { + debug!("detections stream has completed"); + } + } } } } @@ -63,6 +80,11 @@ impl DetectionBatchStream { // Disable random branch selection to poll the futures in order biased; + // Receive shutdown signal + _ = shutdown_rx.recv() => { + warn!("received shutdown signal, terminating task"); + break; + }, // Receive detections and push to batcher msg = stream_set.next(), if !stream_completed => { match msg { diff --git a/src/orchestrator/types/detection_batcher/completion.rs b/src/orchestrator/types/detection_batcher/completion.rs index 218f684f..05b045a9 100644 --- a/src/orchestrator/types/detection_batcher/completion.rs +++ b/src/orchestrator/types/detection_batcher/completion.rs @@ -96,7 +96,7 @@ mod test { use std::task::Poll; use futures::StreamExt; - use tokio::sync::mpsc; + use tokio::sync::{broadcast, mpsc}; use tokio_stream::wrappers::ReceiverStream; use super::*; @@ -464,7 +464,9 @@ mod test { // Create detection batch stream let streams = vec![pii_detections_stream, hap_detections_stream]; - let mut detection_batch_stream = DetectionBatchStream::new(batcher, streams); + // Create channel to shutdown detection pipeline + let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1); + let mut detection_batch_stream = DetectionBatchStream::new(batcher, streams, shutdown_rx); for choice_index in 0..choices { // Send chunk-2 detections for pii detector diff --git a/src/orchestrator/types/detection_batcher/max_processed_index.rs b/src/orchestrator/types/detection_batcher/max_processed_index.rs index 363bb53f..67aa727b 100644 --- a/src/orchestrator/types/detection_batcher/max_processed_index.rs +++ b/src/orchestrator/types/detection_batcher/max_processed_index.rs @@ -88,7 +88,7 @@ mod test { use std::task::Poll; use futures::StreamExt; - use tokio::sync::mpsc; + use tokio::sync::{broadcast, mpsc}; use tokio_stream::wrappers::ReceiverStream; use super::*; @@ -287,7 +287,9 @@ mod test { // Create detection batch stream let streams = vec![pii_detections_stream, hap_detections_stream]; - let mut detection_batch_stream = DetectionBatchStream::new(batcher, streams); + // Create channel to shutdown detection pipeline + let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1); + let mut detection_batch_stream = DetectionBatchStream::new(batcher, streams, shutdown_rx); // Send chunk-2 detections for pii detector let _ = pii_detections_tx