Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::{collections::HashMap, sync::Arc};

use futures::{StreamExt, future::try_join_all};
use opentelemetry::trace::TraceId;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc};
use tracing::{Instrument, debug, error, info, instrument, warn};
use uuid::Uuid;

Expand Down Expand Up @@ -121,7 +121,7 @@ pub async fn handle_streaming(

if output_detectors.is_empty() {
// No output detectors, forward chat completion chunks to response channel
process_chat_completion_stream(trace_id, chat_completion_stream, None, None, Some(response_tx.clone())).await;
process_chat_completion_stream(trace_id, chat_completion_stream, None, None, response_tx.clone(), true, None).await;
info!(%trace_id, "task completed: chat completion stream closed");
} else {
// Handle output detection
Expand Down Expand Up @@ -255,6 +255,8 @@ async fn handle_output_detection(

if !chunk_detectors.is_empty() {
// Set up streaming detection pipeline
// Create channel to shutdown detection pipeline
let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
// n represents how many choices to generate for each input message
// Choices are processed independently so each choice has its own input channels and detection streams.
let n = request.extra.get("n").and_then(|v| v.as_i64()).unwrap_or(1) as usize;
Expand Down Expand Up @@ -295,12 +297,15 @@ async fn handle_output_detection(
chat_completion_stream,
Some(completion_state.clone()),
Some(input_txs),
None,
response_tx.clone(),
false,
Some(shutdown_tx.clone()),
));
// Process detection streams and await completion
let detection_batch_stream = DetectionBatchStream::new(
CompletionBatcher::new(chunk_detectors.len()),
detection_streams,
shutdown_rx,
);
process_detection_batch_stream(
trace_id,
Expand All @@ -317,7 +322,9 @@ async fn handle_output_detection(
chat_completion_stream,
Some(completion_state.clone()),
None,
Some(response_tx.clone()),
response_tx.clone(),
true,
None,
)
.await;
}
Expand Down Expand Up @@ -368,15 +375,18 @@ async fn process_chat_completion_stream(
mut chat_completion_stream: ChatCompletionStream,
completion_state: Option<Arc<CompletionState<ChatCompletionChunk>>>,
input_txs: Option<HashMap<u32, mpsc::Sender<Result<(usize, String), Error>>>>,
response_tx: Option<mpsc::Sender<Result<Option<ChatCompletionChunk>, Error>>>,
response_tx: mpsc::Sender<Result<Option<ChatCompletionChunk>, Error>>,
passthrough: bool,
shutdown_tx: Option<broadcast::Sender<()>>,
) {
let mut no_generated_text = false;
while let Some((message_index, result)) = chat_completion_stream.next().await {
match result {
Ok(Some(chat_completion)) => {
// Send chat completion chunk to response channel
// NOTE: this forwards chat completion chunks without detections and is only
// done here for 2 cases: a) no output detectors b) only whole doc output detectors
if let Some(response_tx) = &response_tx
if passthrough
&& response_tx
.send(Ok(Some(chat_completion.clone())))
.await
Expand All @@ -385,6 +395,32 @@ async fn process_chat_completion_stream(
info!(%trace_id, "task completed: client disconnected");
return;
}

// First message contains finish_reason, no text was generated
if message_index == 0
&& chat_completion
.choices
.first()
.and_then(|choice| choice.finish_reason.as_ref())
.is_some()
{
warn!(%trace_id, ?chat_completion, "first message contains finish_reason, no text was generated");
no_generated_text = true;
// Send shutdown signal to detection batch stream
if let Some(shutdown_tx) = &shutdown_tx {
let _ = shutdown_tx.send(());
}
// Send stop message
let _ = response_tx.send(Ok(Some(chat_completion))).await;
// NOTE: we can't terminate here as the next (final) message contains usage and also needs to be sent
continue;
}
if no_generated_text && chat_completion.usage.is_some() {
// Send usage message and terminate task
let _ = response_tx.send(Ok(Some(chat_completion))).await;
return;
}

if let Some(usage) = &chat_completion.usage
&& chat_completion.choices.is_empty()
{
Expand Down Expand Up @@ -434,9 +470,7 @@ async fn process_chat_completion_stream(
Err(error) => {
error!(%trace_id, %error, "task failed: error received from chat completion stream");
// Send error to response channel
if let Some(response_tx) = &response_tx {
let _ = response_tx.send(Err(error.clone())).await;
}
let _ = response_tx.send(Err(error.clone())).await;
// Send error to detection input channels
if let Some(input_txs) = &input_txs {
for input_tx in input_txs.values() {
Expand Down
52 changes: 43 additions & 9 deletions src/orchestrator/handlers/completions_detection/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::{collections::HashMap, sync::Arc};

use futures::{StreamExt, future::try_join_all};
use opentelemetry::trace::TraceId;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc};
use tracing::{Instrument, debug, error, info, instrument, warn};
use uuid::Uuid;

Expand Down Expand Up @@ -120,7 +120,7 @@ pub async fn handle_streaming(

if output_detectors.is_empty() {
// No output detectors, forward completion chunks to response channel
process_completion_stream(trace_id, completion_stream, None, None, Some(response_tx.clone())).await;
process_completion_stream(trace_id, completion_stream, None, None, response_tx.clone(), true, None).await;
info!(%trace_id, "task completed: completion stream closed");
} else {
// Handle output detection
Expand Down Expand Up @@ -239,6 +239,8 @@ async fn handle_output_detection(

if !chunk_detectors.is_empty() {
// Set up streaming detection pipeline
// Create channel to shutdown detection pipeline
let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
// n represents how many choices to generate for each input message
// Choices are processed independently so each choice has its own input channels and detection streams.
let n = request.extra.get("n").and_then(|v| v.as_i64()).unwrap_or(1) as usize;
Expand Down Expand Up @@ -279,12 +281,15 @@ async fn handle_output_detection(
completion_stream,
Some(completion_state.clone()),
Some(input_txs),
None,
response_tx.clone(),
false,
Some(shutdown_tx.clone()),
));
// Process detection streams and await completion
let detection_batch_stream = DetectionBatchStream::new(
CompletionBatcher::new(chunk_detectors.len()),
detection_streams,
shutdown_rx,
);
process_detection_batch_stream(
trace_id,
Expand All @@ -301,7 +306,9 @@ async fn handle_output_detection(
completion_stream,
Some(completion_state.clone()),
None,
Some(response_tx.clone()),
response_tx.clone(),
true,
None,
)
.await;
}
Expand Down Expand Up @@ -350,15 +357,18 @@ async fn process_completion_stream(
mut completion_stream: CompletionStream,
completion_state: Option<Arc<CompletionState<Completion>>>,
input_txs: Option<HashMap<u32, mpsc::Sender<Result<(usize, String), Error>>>>,
response_tx: Option<mpsc::Sender<Result<Option<Completion>, Error>>>,
response_tx: mpsc::Sender<Result<Option<Completion>, Error>>,
passthrough: bool,
shutdown_tx: Option<broadcast::Sender<()>>,
) {
let mut no_generated_text = false;
while let Some((message_index, result)) = completion_stream.next().await {
match result {
Ok(Some(completion)) => {
// Send completion chunk to response channel
// NOTE: this forwards completion chunks without detections and is only
// done here for 2 cases: a) no output detectors b) only whole doc output detectors
if let Some(response_tx) = &response_tx
if passthrough
&& response_tx
.send(Ok(Some(completion.clone())))
.await
Expand All @@ -367,6 +377,32 @@ async fn process_completion_stream(
info!(%trace_id, "task completed: client disconnected");
return;
}

// First message contains finish_reason, no text was generated
if message_index == 0
&& completion
.choices
.first()
.and_then(|choice| choice.finish_reason.as_ref())
.is_some()
{
warn!(%trace_id, ?completion, "first message contains finish_reason, no text was generated");
no_generated_text = true;
// Send shutdown signal to detection batch stream
if let Some(shutdown_tx) = &shutdown_tx {
let _ = shutdown_tx.send(());
}
// Send stop message
let _ = response_tx.send(Ok(Some(completion))).await;
// NOTE: we can't terminate here as the next (final) message contains usage and also needs to be sent
continue;
}
if no_generated_text && completion.usage.is_some() {
// Send usage message and terminate task
let _ = response_tx.send(Ok(Some(completion))).await;
return;
}

if let Some(usage) = &completion.usage
&& completion.choices.is_empty()
{
Expand Down Expand Up @@ -416,9 +452,7 @@ async fn process_completion_stream(
Err(error) => {
error!(%trace_id, %error, "task failed: error received from completion stream");
// Send error to response channel
if let Some(response_tx) = &response_tx {
let _ = response_tx.send(Err(error.clone())).await;
}
let _ = response_tx.send(Err(error.clone())).await;
// Send error to detection input channels
if let Some(input_txs) = &input_txs {
for input_tx in input_txs.values() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::{
use futures::StreamExt;
use http::HeaderMap;
use opentelemetry::trace::TraceId;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::ReceiverStream;
use tracing::{Instrument, error, info, instrument};

Expand Down Expand Up @@ -231,6 +231,8 @@ async fn handle_output_detection(
let trace_id = task.trace_id;
// Create input channel for detection pipeline
let (input_tx, input_rx) = mpsc::channel(128);
// Create channel to shutdown detection pipeline
let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
// Create shared generations
let generations: Arc<RwLock<Vec<ClassifiedGeneratedTextStreamResult>>> =
Arc::new(RwLock::new(Vec::new()));
Expand All @@ -254,6 +256,7 @@ async fn handle_output_detection(
let detection_batch_stream = DetectionBatchStream::new(
MaxProcessedIndexBatcher::new(detectors.len()),
detection_streams,
shutdown_rx,
);
process_detection_batch_stream(
trace_id,
Expand Down
5 changes: 4 additions & 1 deletion src/orchestrator/handlers/streaming_content_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::{collections::HashMap, pin::Pin, sync::Arc};
use futures::{Stream, StreamExt, stream::Peekable};
use http::HeaderMap;
use opentelemetry::trace::TraceId;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::ReceiverStream;
use tracing::{Instrument, error, info, instrument};

Expand Down Expand Up @@ -125,6 +125,8 @@ async fn handle_detection(
) {
// Create input channel for detection pipeline
let (input_tx, input_rx) = mpsc::channel(128);
// Create channel to shutdown detection pipeline
let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
// Create detection streams
let detection_streams =
common::text_contents_detection_streams(ctx, headers, detectors.clone(), 0, input_rx).await;
Expand All @@ -138,6 +140,7 @@ async fn handle_detection(
let detection_batch_stream = DetectionBatchStream::new(
MaxProcessedIndexBatcher::new(detectors.len()),
detection_streams,
shutdown_rx,
);
process_detection_batch_stream(trace_id, detection_batch_stream, response_tx)
.await;
Expand Down
46 changes: 34 additions & 12 deletions src/orchestrator/types/detection_batch_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

*/
use futures::{Stream, StreamExt, stream};
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error};
use tokio::sync::{broadcast, mpsc, oneshot};
use tracing::{debug, error, warn};

use super::{Batch, Chunk, Detection, DetectionBatcher, DetectionStream};
use crate::orchestrator::Error;
Expand All @@ -31,23 +31,40 @@ pub struct DetectionBatchStream {
}

impl DetectionBatchStream {
pub fn new(batcher: impl DetectionBatcher, mut streams: Vec<DetectionStream>) -> Self {
pub fn new(
batcher: impl DetectionBatcher,
mut streams: Vec<DetectionStream>,
mut shutdown_rx: broadcast::Receiver<()>,
) -> Self {
let (batch_tx, batch_rx) = mpsc::channel(32);
// Spawn task to receive detections and process batches
tokio::spawn(async move {
if streams.len() == 1 {
// Skip the batching process for a single detection stream
let mut stream = streams.swap_remove(0);
while let Some(msg) = stream.next().await {
match msg {
Ok(batch) => {
debug!(?batch, "sending batch to batch channel");
let _ = batch_tx.send(Ok(batch)).await;
}
Err(error) => {
error!(?error, "sending error to batch channel");
let _ = batch_tx.send(Err(error)).await;
loop {
tokio::select! {
// Receive shutdown signal
_ = shutdown_rx.recv() => {
warn!("received shutdown signal, terminating task");
break;
},
// Receive detections and send to batch channel
msg = stream.next() => {
match msg {
Some(Ok(batch)) => {
debug!(?batch, "sending batch to batch channel");
let _ = batch_tx.send(Ok(batch)).await;
}
Some(Err(error)) => {
error!(?error, "sending error to batch channel");
let _ = batch_tx.send(Err(error)).await;
break;
}
None => {
debug!("detections stream has completed");
}
}
}
}
}
Expand All @@ -63,6 +80,11 @@ impl DetectionBatchStream {
// Disable random branch selection to poll the futures in order
biased;

// Receive shutdown signal
_ = shutdown_rx.recv() => {
warn!("received shutdown signal, terminating task");
break;
},
// Receive detections and push to batcher
msg = stream_set.next(), if !stream_completed => {
match msg {
Expand Down
6 changes: 4 additions & 2 deletions src/orchestrator/types/detection_batcher/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ mod test {
use std::task::Poll;

use futures::StreamExt;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc};
use tokio_stream::wrappers::ReceiverStream;

use super::*;
Expand Down Expand Up @@ -464,7 +464,9 @@ mod test {

// Create detection batch stream
let streams = vec![pii_detections_stream, hap_detections_stream];
let mut detection_batch_stream = DetectionBatchStream::new(batcher, streams);
// Create channel to shutdown detection pipeline
let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
let mut detection_batch_stream = DetectionBatchStream::new(batcher, streams, shutdown_rx);

for choice_index in 0..choices {
// Send chunk-2 detections for pii detector
Expand Down
Loading