Skip to content

Commit 10486d5

Browse files
committed
Make PostProcess async again.
1 parent 79b5bb9 commit 10486d5

File tree

3 files changed

+24
-24
lines changed

3 files changed

+24
-24
lines changed

LLama/LLamaExecutorBase.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ protected virtual void TryReuseMatchingPrefix()
262262
/// <param name="inferenceParams"></param>
263263
/// <param name="args"></param>
264264
/// <returns></returns>
265-
protected abstract (bool, IReadOnlyList<string>) PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
265+
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
266266

267267
/// <summary>
268268
/// The core inference logic.
@@ -317,7 +317,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
317317
NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
318318
};
319319

320-
AntipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts ?? Array.Empty<string>());
320+
AntipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts ?? []);
321321

322322
await PreprocessInputs(text, args);
323323

@@ -338,7 +338,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
338338
yield return decoded;
339339
}
340340

341-
var (breakGeneration, extraOutputs) = PostProcess(inferenceParams, args);
341+
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
342342
if (extraOutputs is { Count: > 0 })
343343
{
344344
foreach (var item in extraOutputs)

LLama/LLamaInstructExecutor.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,19 +155,19 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
155155
}
156156

157157
/// <inheritdoc />
158-
protected override (bool, IReadOnlyList<string>) PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
158+
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
159159
{
160160
if (_embed_inps.Count <= _consumedTokensCount)
161161
{
162162
if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput))
163163
{
164164
args.WaitForInput = true;
165-
return (true, Array.Empty<string>());
165+
return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
166166
}
167167

168168
if (_pastTokensCount > 0 && args.WaitForInput)
169169
{
170-
return (true, new[] { "\n> " });
170+
return Task.FromResult<(bool, IReadOnlyList<string>)>((true, [ "\n> " ]));
171171
}
172172
}
173173

@@ -181,7 +181,7 @@ protected override (bool, IReadOnlyList<string>) PostProcess(IInferenceParams in
181181
args.RemainedTokens = inferenceParams.MaxTokens;
182182
args.WaitForInput = true;
183183
}
184-
return (false, Array.Empty<string>());
184+
return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
185185
}
186186

187187
/// <inheritdoc />

LLama/LLamaInteractExecutor.cs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace LLama
2121
public class InteractiveExecutor : StatefulExecutorBase
2222
{
2323
private bool _is_prompt_run = true;
24-
24+
2525
// LLava
2626
private int _EmbedImagePosition = -1;
2727
private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
@@ -36,7 +36,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
3636
: base(context, logger)
3737
{
3838
}
39-
39+
4040
/// <summary>
4141
///
4242
/// </summary>
@@ -46,7 +46,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
4646
public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null)
4747
: base(context, clipModel, logger)
4848
{
49-
}
49+
}
5050

5151
/// <inheritdoc />
5252
public override ExecutorBaseState GetStateData()
@@ -89,7 +89,7 @@ public override Task LoadState(ExecutorBaseState data)
8989

9090
return Task.CompletedTask;
9191
}
92-
92+
9393
/// <inheritdoc />
9494
public override async Task SaveState(string filename)
9595
{
@@ -127,7 +127,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
127127
{
128128
throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
129129
}
130-
130+
131131
if (!IsMultiModal)
132132
{
133133
_embed_inps = Context.Tokenize(text, true, true).ToList();
@@ -164,8 +164,8 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
164164
}
165165

166166
/// <inheritdoc />
167-
private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
168-
{
167+
private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
168+
{
169169
// If the prompt contains the tag <image> extract this.
170170
_imageInPrompt = text.Contains("<image>");
171171
if (_imageInPrompt && IsMultiModal)
@@ -196,7 +196,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
196196
{
197197
var line_inp = Context.Tokenize(text, false, true);
198198
_embed_inps.AddRange(line_inp);
199-
args.RemainedTokens -= line_inp.Length;
199+
args.RemainedTokens -= line_inp.Length;
200200
}
201201
}
202202
return Task.CompletedTask;
@@ -208,7 +208,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
208208
/// <param name="inferenceParams"></param>
209209
/// <param name="args"></param>
210210
/// <returns></returns>
211-
protected override (bool, IReadOnlyList<string>) PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
211+
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
212212
{
213213
if (_embed_inps.Count <= _consumedTokensCount)
214214
{
@@ -219,13 +219,13 @@ protected override (bool, IReadOnlyList<string>) PostProcess(IInferenceParams in
219219

220220
if (_pastTokensCount > 0 && args.WaitForInput)
221221
{
222-
return (true, Array.Empty<string>());
222+
return Task.FromResult((true, (IReadOnlyList<string>)[]));
223223
}
224224
}
225225

226226
if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab))
227227
{
228-
return (true, Array.Empty<string>());
228+
return Task.FromResult((true, (IReadOnlyList<string>)[]));
229229
}
230230

231231
if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -234,7 +234,7 @@ protected override (bool, IReadOnlyList<string>) PostProcess(IInferenceParams in
234234
args.WaitForInput = true;
235235
}
236236

237-
return (false, Array.Empty<string>());
237+
return Task.FromResult((false, (IReadOnlyList<string>)[]));
238238
}
239239

240240
/// <inheritdoc />
@@ -267,18 +267,18 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
267267
// Changes to support Multi-Modal LLMs.
268268
//
269269
(DecodeResult, int, int) header, end, result;
270-
if (IsMultiModal && _EmbedImagePosition > 0)
270+
if (IsMultiModal && _EmbedImagePosition > 0)
271271
{
272272
// Tokens previous to the images
273273
header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
274274
_pastTokensCount = header.Item3;
275275

276276
if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
277-
277+
278278
// Images
279-
foreach( var image in _imageEmbedHandles )
279+
foreach (var image in _imageEmbedHandles)
280280
ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);
281-
281+
282282
// Post-image Tokens
283283
end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
284284
_pastTokensCount = end.Item3;
@@ -294,7 +294,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
294294

295295
if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
296296
}
297-
297+
298298

299299
if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
300300
{

0 commit comments

Comments
 (0)