@@ -656,10 +656,15 @@ def __init__(self, proxy: CopilotProvider):
         self.stream_queue: Optional[asyncio.Queue] = None
         self.processing_task: Optional[asyncio.Task] = None

+        self.finish_stream = False
+
+        # For debugging only
+        # self.data_sent = []
+
     def connection_made(self, transport: asyncio.Transport) -> None:
         """Handle successful connection to target"""
         self.transport = transport
-        logger.debug(f"Target transport peer: {transport.get_extra_info('peername')}")
+        logger.debug(f"Connection established to target: {transport.get_extra_info('peername')}")
         self.proxy.target_transport = transport

     def _ensure_output_processor(self) -> None:
@@ -688,7 +693,7 @@ async def _process_stream(self):
         try:

             async def stream_iterator():
-                while True:
+                while not self.stream_queue.empty():
                     incoming_record = await self.stream_queue.get()

                     record_content = incoming_record.get("content", {})
@@ -701,6 +706,9 @@ async def stream_iterator():
                         else:
                             content = choice.get("delta", {}).get("content")

+                        if choice.get("finish_reason", None) == "stop":
+                            self.finish_stream = True
+
                         streaming_choices.append(
                             StreamingChoices(
                                 finish_reason=choice.get("finish_reason", None),
@@ -722,22 +730,18 @@ async def stream_iterator():
                     )
                     yield mr

-            async for record in self.output_pipeline_instance.process_stream(stream_iterator()):
+            async for record in self.output_pipeline_instance.process_stream(
+                stream_iterator(), cleanup_sensitive=False
+            ):
                 chunk = record.model_dump_json(exclude_none=True, exclude_unset=True)
                 sse_data = f"data: {chunk}\n\n".encode("utf-8")
                 chunk_size = hex(len(sse_data))[2:] + "\r\n"
                 self._proxy_transport_write(chunk_size.encode())
                 self._proxy_transport_write(sse_data)
                 self._proxy_transport_write(b"\r\n")

-            sse_data = b"data: [DONE]\n\n"
-            # Add chunk size for DONE message too
-            chunk_size = hex(len(sse_data))[2:] + "\r\n"
-            self._proxy_transport_write(chunk_size.encode())
-            self._proxy_transport_write(sse_data)
-            self._proxy_transport_write(b"\r\n")
-            # Now send the final zero chunk
-            self._proxy_transport_write(b"0\r\n\r\n")
+            if self.finish_stream:
+                self.finish_data()

         except asyncio.CancelledError:
             logger.debug("Stream processing cancelled")
@@ -746,12 +750,37 @@ async def stream_iterator():
             logger.error(f"Error processing stream: {e}")
         finally:
             # Clean up
+            self.stream_queue = None
             if self.processing_task and not self.processing_task.done():
                 self.processing_task.cancel()
-            if self.proxy.context_tracking and self.proxy.context_tracking.sensitive:
-                self.proxy.context_tracking.sensitive.secure_cleanup()
+
+    def finish_data(self):
+        logger.debug("Finishing data stream")
+        sse_data = b"data: [DONE]\n\n"
+        # Add chunk size for DONE message too
+        chunk_size = hex(len(sse_data))[2:] + "\r\n"
+        self._proxy_transport_write(chunk_size.encode())
+        self._proxy_transport_write(sse_data)
+        self._proxy_transport_write(b"\r\n")
+        # Now send the final zero chunk
+        self._proxy_transport_write(b"0\r\n\r\n")
+
+        # For debugging only
+        # print("===========START DATA SENT====================")
+        # for data in self.data_sent:
+        #     print(data)
+        # self.data_sent = []
+        # print("===========START DATA SENT====================")
+
+        self.finish_stream = False
+        self.headers_sent = False

     def _process_chunk(self, chunk: bytes):
+        # For debugging only
+        # print("===========START DATA RECVD====================")
+        # print(chunk)
+        # print("===========END DATA RECVD======================")
+
         records = self.sse_processor.process_chunk(chunk)

         for record in records:
@@ -763,14 +792,12 @@ def _process_chunk(self, chunk: bytes):
             self.stream_queue.put_nowait(record)

     def _proxy_transport_write(self, data: bytes):
+        # For debugging only
+        # self.data_sent.append(data)
         if not self.proxy.transport or self.proxy.transport.is_closing():
             logger.error("Proxy transport not available")
             return
         self.proxy.transport.write(data)
-        # print("DEBUG =================================")
-        # print(data)
-        # print("DEBUG =================================")
-

     def data_received(self, data: bytes) -> None:
         """Handle data received from target"""
@@ -788,7 +815,7 @@ def data_received(self, data: bytes) -> None:
             if header_end != -1:
                 self.headers_sent = True
                 # Send headers first
-                headers = data[: header_end]
+                headers = data[:header_end]

                 # If Transfer-Encoding is not present, add it
                 if b"Transfer-Encoding:" not in headers:
@@ -800,15 +827,13 @@ def data_received(self, data: bytes) -> None:
                 logger.debug(f"Headers sent: {headers}")

                 data = data[header_end + 4 :]
-                # print("DEBUG =================================")
-                # print(data)
-                # print("DEBUG =================================")

         self._process_chunk(data)

     def connection_lost(self, exc: Optional[Exception]) -> None:
         """Handle connection loss to target"""

+        logger.debug("Lost connection to target")
         if (
             not self.proxy._closing
             and self.proxy.transport
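Note (reviewer sketch, not part of the diff): the new finish_data() path relies on the response being sent with HTTP/1.1 chunked transfer encoding, so every SSE event, including the final data: [DONE], is framed as a hex chunk length, CRLF, the payload, CRLF, and the body is closed with a zero-length chunk (0\r\n\r\n). A minimal Python illustration of that framing, using a hypothetical write callable standing in for _proxy_transport_write:

    def write_chunk(write, payload: bytes) -> None:
        # One HTTP/1.1 chunk: hex length, CRLF, payload, CRLF
        write(hex(len(payload))[2:].encode() + b"\r\n")
        write(payload)
        write(b"\r\n")

    def finish(write) -> None:
        # Final SSE event, then the terminating zero-length chunk
        write_chunk(write, b"data: [DONE]\n\n")
        write(b"0\r\n\r\n")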