@@ -21,7 +21,7 @@ namespace LLama
2121 public class InteractiveExecutor : StatefulExecutorBase
2222 {
2323 private bool _is_prompt_run = true ;
24-
24+
2525 // LLava
2626 private int _EmbedImagePosition = - 1 ;
2727 private List < SafeLlavaImageEmbedHandle > _imageEmbedHandles = new List < SafeLlavaImageEmbedHandle > ( ) ;
@@ -36,7 +36,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
3636 : base ( context , logger )
3737 {
3838 }
39-
39+
4040 /// <summary>
4141 ///
4242 /// </summary>
@@ -46,7 +46,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
4646 public InteractiveExecutor ( LLamaContext context , LLavaWeights clipModel , ILogger ? logger = null )
4747 : base ( context , clipModel , logger )
4848 {
49- }
49+ }
5050
5151 /// <inheritdoc />
5252 public override ExecutorBaseState GetStateData ( )
@@ -67,6 +67,7 @@ public override ExecutorBaseState GetStateData()
6767 } ;
6868 return state ;
6969 }
70+
7071 /// <inheritdoc />
7172 public override Task LoadState ( ExecutorBaseState data )
7273 {
@@ -88,23 +89,23 @@ public override Task LoadState(ExecutorBaseState data)
8889
8990 return Task . CompletedTask ;
9091 }
92+
/// <inheritdoc />
public override async Task SaveState(string filename)
{
    // Snapshot the executor state, then persist it to disk as JSON.
    var state = (InteractiveExecutorState)GetStateData();
    using var stateStream = new FileStream(filename, FileMode.Create, FileAccess.Write);
    await JsonSerializer.SerializeAsync(stateStream, state);
}
102+
100103 /// <inheritdoc />
101104 public override async Task LoadState ( string filename )
102105 {
103- using ( var fs = new FileStream ( filename , FileMode . Open , FileAccess . Read ) )
104- {
105- var state = await JsonSerializer . DeserializeAsync < InteractiveExecutorState > ( fs ) ;
106- await LoadState ( state ! ) ;
107- }
106+ using var fs = new FileStream ( filename , FileMode . Open , FileAccess . Read ) ;
107+ var state = await JsonSerializer . DeserializeAsync < InteractiveExecutorState > ( fs ) ;
108+ await LoadState ( state ! ) ;
108109 }
109110
110111 /// <summary>
@@ -122,7 +123,11 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
122123 if ( _is_prompt_run )
123124 {
124125 // When running the first input (prompt) in interactive mode, we should specially process it.
125- if ( text == null ) throw new ArgumentException ( "Prompt cannot be null to trigger continuation if a prompt has not been provided previously." ) ;
126+ if ( text == null )
127+ {
128+ throw new ArgumentException ( "Prompt cannot be null to trigger continuation if a prompt has not been provided previously." ) ;
129+ }
130+
126131 if ( ! IsMultiModal )
127132 {
128133 _embed_inps = Context . Tokenize ( text , true , true ) . ToList ( ) ;
@@ -159,8 +164,8 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
159164 }
160165
161166 /// <inheritdoc />
162- private Task PreprocessLlava ( string text , InferStateArgs args , bool addBos = true )
163- {
167+ private Task PreprocessLlava ( string text , InferStateArgs args , bool addBos = true )
168+ {
164169 // If the prompt contains the tag <image> extract this.
165170 _imageInPrompt = text . Contains ( "<image>" ) ;
166171 if ( _imageInPrompt && IsMultiModal )
@@ -191,7 +196,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
191196 {
192197 var line_inp = Context . Tokenize ( text , false , true ) ;
193198 _embed_inps . AddRange ( line_inp ) ;
194- args . RemainedTokens -= line_inp . Length ;
199+ args . RemainedTokens -= line_inp . Length ;
195200 }
196201 }
197202 return Task . CompletedTask ;
@@ -203,20 +208,24 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
203208 /// <param name="inferenceParams"></param>
204209 /// <param name="args"></param>
205210 /// <returns></returns>
/// <summary>
/// Decides, after a decode step, whether inference should stop or wait for more user input.
/// </summary>
/// <param name="inferenceParams">Inference parameters; <c>MaxTokens == -1</c> means unlimited.</param>
/// <param name="args">Mutable per-run state (remaining token budget, wait flag, last output).</param>
/// <returns>
/// A tuple whose first element is <c>true</c> when generation should break out of the
/// inference loop, plus any extra output strings (always empty here).
/// </returns>
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
    if (_embed_inps.Count <= _consumedTokensCount)
    {
        // An antiprompt was matched in the latest output: pause and wait for the user.
        if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput))
        {
            args.WaitForInput = true;
        }

        if (_pastTokensCount > 0 && args.WaitForInput)
        {
            return Task.FromResult((true, (IReadOnlyList<string>)[]));
        }
    }

    // The model emitted an end-of-generation token: stop.
    if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab))
    {
        return Task.FromResult((true, (IReadOnlyList<string>)[]));
    }

    if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
    {
        // Token budget exhausted: reset it and wait for input.
        // NOTE(review): these two lines fall in a diff-hunk gap and were reconstructed
        // from the surrounding context — confirm against the full file.
        args.RemainedTokens = inferenceParams.MaxTokens;
        args.WaitForInput = true;
    }

    // BUG FIX: the fall-through case means "not finished, keep inferring" and must
    // return false. The diff changed this to true, which would terminate generation
    // on every pass even when no stop condition was met.
    return Task.FromResult((false, (IReadOnlyList<string>)[]));
}
230239
231240 /// <inheritdoc />
@@ -258,18 +267,18 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
258267 // Changes to support Multi-Modal LLMs.
259268 //
260269 ( DecodeResult , int , int ) header , end , result ;
261- if ( IsMultiModal && _EmbedImagePosition > 0 )
270+ if ( IsMultiModal && _EmbedImagePosition > 0 )
262271 {
263272 // Tokens previous to the images
264273 header = await Context . DecodeAsync ( _embeds . GetRange ( 0 , _EmbedImagePosition ) , LLamaSeqId . Zero , batch , _pastTokensCount ) ;
265274 _pastTokensCount = header . Item3 ;
266275
267276 if ( header . Item1 != DecodeResult . Ok ) throw new LLamaDecodeError ( header . Item1 ) ;
268-
277+
269278 // Images
270- foreach ( var image in _imageEmbedHandles )
279+ foreach ( var image in _imageEmbedHandles )
271280 ClipModel ! . EvalImageEmbed ( Context , image , ref _pastTokensCount ) ;
272-
281+
273282 // Post-image Tokens
274283 end = await Context . DecodeAsync ( _embeds . GetRange ( _EmbedImagePosition , _embeds . Count - _EmbedImagePosition ) , LLamaSeqId . Zero , batch , _pastTokensCount ) ;
275284 _pastTokensCount = end . Item3 ;
@@ -285,7 +294,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
285294
286295 if ( result . Item1 != DecodeResult . Ok ) throw new LLamaDecodeError ( result . Item1 ) ;
287296 }
288-
297+
289298
290299 if ( _embeds . Count > 0 && ! string . IsNullOrEmpty ( _pathSession ) )
291300 {
0 commit comments