1616 TEMPLATE_CHUNKS , TEMPLATE_NO_CHUNKS , TEMPLATE_MERGE ,
1717 TEMPLATE_CHUNKS_MD , TEMPLATE_NO_CHUNKS_MD , TEMPLATE_MERGE_MD
1818)
19+ from langchain .callbacks .manager import CallbackManager
20+ from langchain .callbacks import get_openai_callback
21+ from requests .exceptions import Timeout
22+ import time
1923
2024class GenerateAnswerNode (BaseNode ):
2125 """
@@ -56,6 +60,7 @@ def __init__(
5660 self .script_creator = node_config .get ("script_creator" , False )
5761 self .is_md_scraper = node_config .get ("is_md_scraper" , False )
5862 self .additional_info = node_config .get ("additional_info" )
63+ self .timeout = node_config .get ("timeout" , 30 )
5964
6065 def execute (self , state : dict ) -> dict :
6166 """
@@ -114,14 +119,33 @@ def execute(self, state: dict) -> dict:
114119 template_chunks_prompt = self .additional_info + template_chunks_prompt
115120 template_merge_prompt = self .additional_info + template_merge_prompt
116121
def invoke_with_timeout(chain, inputs, timeout):
    """Invoke *chain* with *inputs*, enforcing a hard wall-clock timeout.

    The previous implementation only compared ``time.time()`` *after*
    ``chain.invoke`` had already returned: a slow call could never be
    interrupted, and a response that completed late was thrown away in
    favour of a ``Timeout``. Running the call in a one-worker thread and
    bounding the wait with ``Future.result(timeout=...)`` makes the
    timeout real. (The unused ``get_openai_callback()`` context — its
    ``cb`` was never read — has been dropped.)

    Args:
        chain: Any runnable exposing ``invoke(inputs)``.
        inputs (dict): Payload forwarded to ``chain.invoke``.
        timeout (float): Maximum number of seconds to wait for a response.

    Returns:
        Whatever ``chain.invoke(inputs)`` returns.

    Raises:
        Timeout: If no response arrives within ``timeout`` seconds.
        Exception: Any error raised by the chain itself (logged, re-raised).
    """
    # Local import keeps this nested helper self-contained.
    import concurrent.futures

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(chain.invoke, inputs)
        try:
            return future.result(timeout=timeout)
        except concurrent.futures.TimeoutError as exc:
            # Same message/exception type the callers already catch.
            raise Timeout(f"Response took longer than {timeout} seconds") from exc
    except Timeout as e:
        self.logger.error(f"Timeout error: {str(e)}")
        raise
    except Exception as e:
        self.logger.error(f"Error during chain execution: {str(e)}")
        raise
    finally:
        # Don't block on a still-running call; just stop accepting work.
        # NOTE(review): the worker thread may keep running in the background
        # after a timeout — confirm that is acceptable for the LLM client.
        executor.shutdown(wait=False)
136+
117137 if len (doc ) == 1 :
118138 prompt = PromptTemplate (
119139 template = template_no_chunks_prompt ,
120140 input_variables = ["question" ],
121141 partial_variables = {"context" : doc , "format_instructions" : format_instructions }
122142 )
123143 chain = prompt | self .llm_model
124- raw_response = chain .invoke ({"question" : user_prompt })
144+ try :
145+ raw_response = invoke_with_timeout (chain , {"question" : user_prompt }, self .timeout )
146+ except Timeout :
147+ state .update ({self .output [0 ]: {"error" : "Response timeout exceeded" }})
148+ return state
125149
126150 if output_parser :
127151 try :
@@ -155,7 +179,15 @@ def execute(self, state: dict) -> dict:
155179 chains_dict [chain_name ] = chains_dict [chain_name ] | output_parser
156180
157181 async_runner = RunnableParallel (** chains_dict )
158- batch_results = async_runner .invoke ({"question" : user_prompt })
182+ try :
183+ batch_results = invoke_with_timeout (
184+ async_runner ,
185+ {"question" : user_prompt },
186+ self .timeout
187+ )
188+ except Timeout :
189+ state .update ({self .output [0 ]: {"error" : "Response timeout exceeded during chunk processing" }})
190+ return state
159191
160192 merge_prompt = PromptTemplate (
161193 template = template_merge_prompt ,
@@ -166,7 +198,15 @@ def execute(self, state: dict) -> dict:
166198 merge_chain = merge_prompt | self .llm_model
167199 if output_parser :
168200 merge_chain = merge_chain | output_parser
169- answer = merge_chain .invoke ({"context" : batch_results , "question" : user_prompt })
201+ try :
202+ answer = invoke_with_timeout (
203+ merge_chain ,
204+ {"context" : batch_results , "question" : user_prompt },
205+ self .timeout
206+ )
207+ except Timeout :
208+ state .update ({self .output [0 ]: {"error" : "Response timeout exceeded during merge" }})
209+ return state
170210
171211 state .update ({self .output [0 ]: answer })
172212 return state
0 commit comments