@@ -239,36 +239,41 @@ protected virtual void TryReuseMatchingPrefix()
239239 /// Decide whether to continue the loop.
240240 /// </summary>
241241 /// <param name="args">State arguments of the current inference call.</param>
242+ /// <param name="cancellationToken">Token the caller can use to request cancellation; optional, defaults to <c>default</c>.</param>
242243 /// <returns>A task whose result is <c>true</c> to run another loop iteration and <c>false</c> to stop.</returns>
243- protected abstract Task < bool > GetLoopCondition ( InferStateArgs args ) ;
244+ protected abstract Task < bool > GetLoopCondition ( InferStateArgs args , CancellationToken cancellationToken = default ) ;
244245
245246 /// <summary>
246247 /// Preprocess the inputs before the inference.
247248 /// </summary>
248249 /// <param name="text">Input text to preprocess; may be <c>null</c>.</param>
249250 /// <param name="args">State arguments of the current inference call.</param>
250- protected abstract Task PreprocessInputs ( string ? text , InferStateArgs args ) ;
251+ /// <param name="cancellationToken">Token the caller can use to request cancellation; optional, defaults to <c>default</c>.</param>
252+ protected abstract Task PreprocessInputs ( string ? text , InferStateArgs args , CancellationToken cancellationToken = default ) ;
251253
252254 /// <summary>
253255 /// Do some post processing after the inference.
254256 /// </summary>
255257 /// <param name="inferenceParams">Inference parameters of the current call.</param>
256258 /// <param name="args">State arguments of the current inference call.</param>
259+ /// <param name="cancellationToken">Token the caller can use to request cancellation; optional, defaults to <c>default</c>.</param>
257260 /// <returns>A task producing a tuple: a flag the calling loop reads as "break generation", and a list of extra outputs the caller iterates when it is non-empty.</returns>
258- protected abstract Task < ( bool , IReadOnlyList < string > ) > PostProcess ( IInferenceParams inferenceParams , InferStateArgs args ) ;
261+ protected abstract Task < ( bool , IReadOnlyList < string > ) > PostProcess ( IInferenceParams inferenceParams , InferStateArgs args , CancellationToken cancellationToken = default ) ;
259262
260263 /// <summary>
261264 /// The core inference logic.
262265 /// </summary>
263266 /// <param name="inferenceParams">Inference parameters of the current call.</param>
264267 /// <param name="args">State arguments of the current inference call.</param>
265- protected abstract Task InferInternal ( IInferenceParams inferenceParams , InferStateArgs args ) ;
268+ /// <param name="cancellationToken">Token the caller can use to request cancellation; optional, defaults to <c>default</c>.</param>
269+ protected abstract Task InferInternal ( IInferenceParams inferenceParams , InferStateArgs args , CancellationToken cancellationToken = default ) ;
266270
267271 /// <summary>
268272 /// Save the current state to a file.
269273 /// </summary>
270274 /// <param name="filename">Name of the file the state is saved to.</param>
271- public abstract Task SaveState ( string filename ) ;
275+ /// <param name="cancellationToken">Token the caller can use to request cancellation; optional, defaults to <c>default</c>.</param>
276+ public abstract Task SaveState ( string filename , CancellationToken cancellationToken = default ) ;
272277
273278 /// <summary>
274279 /// Get the current state data.
@@ -280,13 +285,15 @@ protected virtual void TryReuseMatchingPrefix()
280285 /// Load the state from data.
281286 /// </summary>
282287 /// <param name="data">State data to load from.</param>
283- public abstract Task LoadState ( ExecutorBaseState data ) ;
288+ /// <param name="cancellationToken">Token the caller can use to request cancellation; optional, defaults to <c>default</c>.</param>
289+ public abstract Task LoadState ( ExecutorBaseState data , CancellationToken cancellationToken = default ) ;
284290
285291 /// <summary>
286292 /// Load the state from a file.
287293 /// </summary>
288294 /// <param name="filename">Name of the file the state is loaded from.</param>
289- public abstract Task LoadState ( string filename ) ;
295+ /// <param name="cancellationToken">Token the caller can use to request cancellation; optional, defaults to <c>default</c>.</param>
296+ public abstract Task LoadState ( string filename , CancellationToken cancellationToken = default ) ;
290297
291298
292299 /// <summary>
@@ -310,23 +317,23 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
310317 NeedToSaveSession = ! string . IsNullOrEmpty ( _pathSession ) && _n_matching_session_tokens < _embed_inps . Count
311318 } ;
312319
313- await PreprocessInputs ( text , args ) ;
320+ await PreprocessInputs ( text , args , cancellationToken ) ;
314321
315- while ( await GetLoopCondition ( args ) )
322+ while ( await GetLoopCondition ( args , cancellationToken ) )
316323 {
317324 if ( cancellationToken . IsCancellationRequested )
318325 {
319326 break ;
320327 }
321- await InferInternal ( inferenceParams , args ) ;
328+ await InferInternal ( inferenceParams , args , cancellationToken ) ;
322329
323330 if ( args . ReturnValue )
324331 {
325332 _decoder . AddRange ( _embeds ) ;
326333 yield return _decoder . Read ( ) ;
327334 }
328335
329- var ( breakGeneration , extraOutputs ) = await PostProcess ( inferenceParams , args ) ;
336+ var ( breakGeneration , extraOutputs ) = await PostProcess ( inferenceParams , args , cancellationToken ) ;
330337 if ( extraOutputs is { Count : > 0 } )
331338 {
332339 foreach ( var item in extraOutputs )
@@ -346,8 +353,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
346353 /// It could reduce the latency of the first time response if the first input from the user is not immediate.
347354 /// </summary>
348355 /// <param name="prompt">Prompt to process</param>
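356+ /// <param name="cancellationToken">Token forwarded to the input preprocessing and inference calls made during the prefill; optional, defaults to <c>default</c>.</param>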
349357 /// <returns></returns>
350- public virtual async Task PrefillPromptAsync ( string prompt )
358+ public virtual async Task PrefillPromptAsync ( string prompt , CancellationToken cancellationToken = default )
351359 {
352360 var inferenceParams = new InferenceParams
353361 {
@@ -362,11 +370,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
362370 NeedToSaveSession = false
363371 } ;
364372
365- await PreprocessInputs ( prompt , args ) ;
373+ await PreprocessInputs ( prompt , args , cancellationToken ) ;
366374 // First run adds the prompt to the _embeds
367- await InferInternal ( inferenceParams , args ) ;
375+ await InferInternal ( inferenceParams , args , cancellationToken ) ;
368376 // Second run puts it through decode
369- await InferInternal ( inferenceParams , args ) ;
377+ await InferInternal ( inferenceParams , args , cancellationToken ) ;
370378 }
371379
372380 /// <summary>
0 commit comments