
Commit 7aeb196

General Updates:

- [NEW] support cancellation of `ChatSession.InitializeSessionFromHistoryAsync`
- [NEW] improve usage of `CancellationToken`s in `LLamaExecutorBase`
- [FIX] `CS1998` warnings
1 parent de00c15

6 files changed (+104 lines, -83 lines)
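
The headline change is the new optional `cancellationToken` parameter on `ChatSession.InitializeSessionFromHistoryAsync`, which now flows through to the underlying prompt prefill. A minimal caller-side sketch (not part of this commit's diff; the model path and timeout are illustrative assumptions):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

public static class CancellableInitSketch
{
    public static async Task Main()
    {
        // Hypothetical model path; any GGUF model works here.
        var parameters = new ModelParams("models/example.gguf");
        using var weights = LLamaWeights.LoadFromFile(parameters);
        using var context = weights.CreateContext(parameters);
        var executor = new InteractiveExecutor(context);

        var history = new ChatHistory();
        history.AddMessage(AuthorRole.System, "You are a helpful assistant.");

        // Abandon the history prefill if it takes longer than 30 seconds.
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
        try
        {
            var session = await ChatSession.InitializeSessionFromHistoryAsync(
                executor, history, cancellationToken: cts.Token);
            Console.WriteLine("Session ready.");
        }
        catch (OperationCanceledException)
        {
            // Depending on which stage observes the token, cancellation may
            // surface here or simply stop the prefill loop early.
            Console.WriteLine("Initialization cancelled.");
        }
    }
}
```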

LLama.Examples/Examples/QuantizeModel.cs

Lines changed: 3 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@ namespace LLama.Examples.Examples
 {
     public class QuantizeModel
     {
-        public static async Task Run()
+        public static Task Run()
         {
             string inputPath = UserSettings.GetModelPath();
 
@@ -20,6 +20,8 @@ public static async Task Run()
             {
                 Console.WriteLine("Quantization failed!");
             }
+
+            return Task.CompletedTask;
         }
     }
 }
```
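
This is the `CS1998` fix named in the commit message: the method was declared `async` but contained no `await`, so the compiler warned that it would run synchronously anyway. Dropping the modifier and returning `Task.CompletedTask` keeps the `Task`-returning signature without the async state machine. A standalone sketch of the pattern (hypothetical method names):

```csharp
using System;
using System.Threading.Tasks;

public static class Cs1998Sketch
{
    // Before: triggers warning CS1998 ("This async method lacks 'await'
    // operators and will run synchronously").
    public static async Task RunWithWarning()
    {
        Console.WriteLine("purely synchronous work");
    }

    // After: same awaitable signature, no warning, no state machine.
    public static Task Run()
    {
        Console.WriteLine("purely synchronous work");
        return Task.CompletedTask;
    }
}
```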

LLama/ChatSession.cs

Lines changed: 13 additions & 10 deletions
```diff
@@ -76,9 +76,10 @@ public class ChatSession
     /// <param name="executor">The executor for this session</param>
     /// <param name="history">History for this session</param>
     /// <param name="transform">History Transform for this session</param>
+    /// <param name="cancellationToken">A token that cancels the operation</param>
     /// <returns>A new chat session.</returns>
     public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
-        ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
+        ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null, CancellationToken cancellationToken = default)
     {
         if (executor is not StatefulExecutorBase statefulExecutor)
         {
@@ -90,7 +91,7 @@ public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
             session = session.WithHistoryTransform(transform);
         }
 
-        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
+        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
         return session;
     }
 
@@ -311,13 +312,15 @@ public ChatSession RemoveLastMessage()
     /// Compute KV cache for the message and add it to the chat history.
     /// </summary>
     /// <param name="message"></param>
+    /// <param name="cancellationToken"></param>
     /// <returns></returns>
-    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
+    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message, CancellationToken cancellationToken = default)
     {
         if (Executor is not StatefulExecutorBase statefulExecutor)
         {
             throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
         }
+
         AddMessage(message);
         var content = message.Content;
         if (message.AuthorRole != AuthorRole.Assistant)
@@ -328,27 +331,27 @@ public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
             }
         }
 
-        await statefulExecutor.PrefillPromptAsync(content);
+        await statefulExecutor.PrefillPromptAsync(content, cancellationToken);
         return this;
     }
 
     /// <summary>
     /// Compute KV cache for the system message and add it to the chat history.
     /// </summary>
-    public Task<ChatSession> AddAndProcessSystemMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
+    public Task<ChatSession> AddAndProcessSystemMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content), cancellationToken);
 
     /// <summary>
     /// Compute KV cache for the user message and add it to the chat history.
     /// </summary>
-    public Task<ChatSession> AddAndProcessUserMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
+    public Task<ChatSession> AddAndProcessUserMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content), cancellationToken);
 
     /// <summary>
     /// Compute KV cache for the assistant message and add it to the chat history.
     /// </summary>
-    public Task<ChatSession> AddAndProcessAssistantMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+    public Task<ChatSession> AddAndProcessAssistantMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content), cancellationToken);
 
     /// <summary>
     /// Replace a user message with a new message and remove all messages after the new message.
```
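
All four `AddAndProcess*` helpers now accept an optional token that flows into the KV-cache prefill. A hedged sketch of warming up a system prompt under a deadline (the helper name and timeout are assumptions, not code from this commit):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;

public static class PrefillSketch
{
    // `session` must wrap a StatefulExecutorBase-derived executor
    // (e.g. InteractiveExecutor); otherwise AddAndProcessMessage throws.
    public static async Task WarmUpAsync(ChatSession session)
    {
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));

        // Computes the KV cache for the system prompt now, so the first real
        // user turn starts from an already-prefilled context.
        await session.AddAndProcessSystemMessage(
            "You are a terse assistant.", cts.Token);
    }
}
```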

LLama/LLamaContext.cs

Lines changed: 10 additions & 9 deletions
```diff
@@ -1,14 +1,14 @@
-using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Text;
 using System.IO;
 using System.IO.MemoryMappedFiles;
+using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
+using LLama.Native;
 using Microsoft.Extensions.Logging;
-using System.Threading;
 
 namespace LLama
 {
@@ -73,7 +73,7 @@ public int BatchThreads
         /// Get the special tokens for the model associated with this context
         /// </summary>
         public SafeLlamaModelHandle.Vocabulary Vocab { get; }
-        
+
         /// <summary>
         /// Create a new LLamaContext for the given LLamaWeights
         /// </summary>
@@ -396,7 +396,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
        {
            return Task.Run(() => Decode(batch), cancellationToken);
        }
-        
+
        /// <summary>
        /// </summary>
        /// <param name="batch"></param>
@@ -406,10 +406,10 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
                return 0;
            if (batch.EmbeddingsCount > BatchSize)
                throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
-            
+
            return (DecodeResult)NativeHandle.Decode(batch);
        }
-        
+
        /// <summary>
        /// </summary>
        /// <param name="batch"></param>
@@ -425,15 +425,16 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
        /// <param name="id"></param>
        /// <param name="batch"></param>
        /// <param name="n_past"></param>
+       /// <param name="cancellationToken"></param>
        /// <returns>A tuple, containing the decode result, the number of tokens that have <b>not</b> been decoded yet and the total number of tokens that have been decoded.</returns>
-       public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+       public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past, CancellationToken cancellationToken = default)
        {
            return Task.Run(() =>
            {
                var past = n_past;
                var res = NativeHandle.Decode(tokens, id, batch, ref past);
                return (res.Item1, res.Item2, past);
-           });
+           }, cancellationToken);
        }
        #endregion
```
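
Worth noting: handing the token to `Task.Run` only cancels the decode if cancellation is requested before the delegate starts; once the native `Decode` call is executing, it runs to completion. A small self-contained illustration of that .NET behaviour:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class TaskRunCancellationSketch
{
    public static async Task Main()
    {
        using var cts = new CancellationTokenSource();
        cts.Cancel(); // cancel *before* the work is scheduled

        try
        {
            // The token is already cancelled, so Task.Run never invokes the
            // delegate; the returned task completes in the Canceled state.
            await Task.Run(() => Console.WriteLine("never printed"), cts.Token);
        }
        catch (OperationCanceledException)
        {
            Console.WriteLine("Cancelled before the work started.");
        }
    }
}
```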

LLama/LLamaExecutorBase.cs

Lines changed: 23 additions & 15 deletions
```diff
@@ -239,36 +239,41 @@ protected virtual void TryReuseMatchingPrefix()
        /// Decide whether to continue the loop.
        /// </summary>
        /// <param name="args"></param>
+       /// <param name="cancellationToken"></param>
        /// <returns></returns>
-       protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
+       protected abstract Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default);
 
        /// <summary>
        /// Preprocess the inputs before the inference.
        /// </summary>
        /// <param name="text"></param>
        /// <param name="args"></param>
-       protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
+       /// <param name="cancellationToken"></param>
+       protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default);
 
        /// <summary>
        /// Do some post processing after the inference.
        /// </summary>
        /// <param name="inferenceParams"></param>
        /// <param name="args"></param>
+       /// <param name="cancellationToken"></param>
        /// <returns></returns>
-       protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
+       protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
 
        /// <summary>
        /// The core inference logic.
        /// </summary>
        /// <param name="inferenceParams"></param>
        /// <param name="args"></param>
-       protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+       /// <param name="cancellationToken"></param>
+       protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
 
        /// <summary>
        /// Save the current state to a file.
        /// </summary>
        /// <param name="filename"></param>
-       public abstract Task SaveState(string filename);
+       /// <param name="cancellationToken"></param>
+       public abstract Task SaveState(string filename, CancellationToken cancellationToken = default);
 
        /// <summary>
        /// Get the current state data.
@@ -280,13 +285,15 @@ protected virtual void TryReuseMatchingPrefix()
        /// Load the state from data.
        /// </summary>
        /// <param name="data"></param>
-       public abstract Task LoadState(ExecutorBaseState data);
+       /// <param name="cancellationToken"></param>
+       public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default);
 
        /// <summary>
        /// Load the state from a file.
        /// </summary>
        /// <param name="filename"></param>
-       public abstract Task LoadState(string filename);
+       /// <param name="cancellationToken"></param>
+       public abstract Task LoadState(string filename, CancellationToken cancellationToken = default);
 
 
        /// <summary>
@@ -310,23 +317,23 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
                NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
            };
 
-           await PreprocessInputs(text, args);
+           await PreprocessInputs(text, args, cancellationToken);
 
-           while (await GetLoopCondition(args))
+           while (await GetLoopCondition(args, cancellationToken))
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    break;
                }
-               await InferInternal(inferenceParams, args);
+               await InferInternal(inferenceParams, args, cancellationToken);
 
                if (args.ReturnValue)
                {
                    _decoder.AddRange(_embeds);
                    yield return _decoder.Read();
                }
 
-               var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+               var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args, cancellationToken);
                if (extraOutputs is { Count: > 0 })
                {
                    foreach (var item in extraOutputs)
@@ -346,8 +353,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
        /// It could reduce the latency of the first time response if the first input from the user is not immediate.
        /// </summary>
        /// <param name="prompt">Prompt to process</param>
+       /// <param name="cancellationToken"></param>
        /// <returns></returns>
-       public virtual async Task PrefillPromptAsync(string prompt)
+       public virtual async Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken = default)
        {
            var inferenceParams = new InferenceParams
            {
@@ -362,11 +370,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
                NeedToSaveSession = false
            };
 
-           await PreprocessInputs(prompt, args);
+           await PreprocessInputs(prompt, args, cancellationToken);
            // First run adds the prompt to the _embeds
-           await InferInternal(inferenceParams, args);
+           await InferInternal(inferenceParams, args, cancellationToken);
            // Second run puts it through decode
-           await InferInternal(inferenceParams, args);
+           await InferInternal(inferenceParams, args, cancellationToken);
        }
 
        /// <summary>
```
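
`InferAsync` already checked `IsCancellationRequested` at the top of each loop iteration; the change is that the same token now reaches `PreprocessInputs`, `GetLoopCondition`, `InferInternal`, and `PostProcess`, so derived executors can react to cancellation inside a stage rather than only between tokens. A hedged caller-side sketch (prompt text and the 200-character cutoff are illustrative):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;

public static class StreamingCancellationSketch
{
    public static async Task StreamAsync(ILLamaExecutor executor)
    {
        using var cts = new CancellationTokenSource();
        var inferenceParams = new InferenceParams { MaxTokens = 512 };
        var written = 0;

        try
        {
            await foreach (var chunk in executor.InferAsync(
                               "Write a long story.", inferenceParams, cts.Token))
            {
                Console.Write(chunk);
                written += chunk.Length;

                // After enough output, request cancellation; the base
                // executor's loop observes the token and stops generating.
                if (written > 200)
                    cts.Cancel();
            }
        }
        catch (OperationCanceledException)
        {
            // A forwarded stage (e.g. a decode Task) may surface the token
            // as an exception instead of a clean break.
        }
    }
}
```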

LLama/LLamaInstructExecutor.cs

Lines changed: 19 additions & 17 deletions
```diff
@@ -1,14 +1,15 @@
-using LLama.Abstractions;
-using LLama.Common;
-using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Exceptions;
+using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
@@ -65,9 +66,9 @@ public override ExecutorBaseState GetStateData()
            return state;
        }
        /// <inheritdoc />
-       public override Task LoadState(ExecutorBaseState data)
+       public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
        {
-           if(data is InstructExecutorState state)
+           if (data is InstructExecutorState state)
            {
                _n_session_consumed = state.ConsumedSessionCount;
                _embed_inps = state.EmbedInps!.ToList();
@@ -91,34 +92,34 @@ public override Task LoadState(ExecutorBaseState data)
        }
 
        /// <inheritdoc />
-       public override async Task SaveState(string filename)
+       public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
        {
            var state = (InstructExecutorState)GetStateData();
            using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
            {
-               await JsonSerializer.SerializeAsync(fs, state);
+               await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
            }
        }
        /// <inheritdoc />
-       public override async Task LoadState(string filename)
+       public override async Task LoadState(string filename, CancellationToken cancellationToken)
        {
            using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
            {
                var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
-               await LoadState(state!);
+               await LoadState(state!, cancellationToken);
            }
        }
 
        /// <inheritdoc />
-       protected override Task<bool> GetLoopCondition(InferStateArgs args)
+       protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
        {
            return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
        }
 
        /// <inheritdoc />
-       protected override Task PreprocessInputs(string? text, InferStateArgs args)
+       protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
        {
-           args.Antiprompts ??= [ ];
+           args.Antiprompts ??= [];
            if (!args.Antiprompts.Contains(_instructionPrefix))
                args.Antiprompts.Add(_instructionPrefix);
 
@@ -154,19 +155,19 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
        }
 
        /// <inheritdoc />
-       protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+       protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
        {
            if (_embed_inps.Count <= _consumedTokensCount)
            {
                if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                {
                    args.WaitForInput = true;
-                   return (true, Array.Empty<string>());
+                   return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
                }
 
                if (_pastTokensCount > 0 && args.WaitForInput)
                {
-                   return (true, new[] { "\n> " });
+                   return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));
                }
            }
 
@@ -180,11 +181,12 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                args.RemainedTokens = inferenceParams.MaxTokens;
                args.WaitForInput = true;
            }
-           return (false, Array.Empty<string>());
+
+           return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
        }
 
        /// <inheritdoc />
-       protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+       protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
        {
            var batch = new LLamaBatch();
 
```
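The `PostProcess` change is the `Task<T>` variant of the same `CS1998` cleanup: the `async` modifier goes away, results are wrapped in `Task.FromResult`, and C# 12 collection expressions (`[]`, `["\n> "]`) replace `Array.Empty<string>()` and `new[] { ... }`. A minimal standalone sketch of the shape (hypothetical method name, requires C# 12):

```csharp
using System.Collections.Generic;
using System.Threading.Tasks;

public static class TaskFromResultSketch
{
    // Synchronous logic behind a Task<T>-returning signature: no 'async',
    // no CS1998, result wrapped explicitly.
    public static Task<(bool, IReadOnlyList<string>)> PostProcessLike(bool waitForInput)
    {
        IReadOnlyList<string> extra = [];   // C# 12 collection expression
        if (waitForInput)
            extra = ["\n> "];
        return Task.FromResult((waitForInput, extra));
    }
}
```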