
Commit 4355809

Enable FlashAttention and remove SeqMax param
FlashAttention is now enabled by default in the model parameter initialization of the KernelMemory embedding and text-generation wrappers. The unused SeqMax parameter has been removed from the unit tests to simplify configuration, and minor formatting improvements were made in IContextParamsExtensions and NativeApi for consistency.
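
From the caller's side nothing changes: the flag is set inside the KernelMemory wrappers themselves. A minimal sketch of the effect (hedged; the LLamaSharpConfig constructor shape is assumed here, not shown in this commit):

    // Construct the embedding generator as before; after this commit the
    // ModelParams it builds internally include FlashAttention = true.
    var config = new LLamaSharpConfig("path/to/model.gguf");
    var embedder = new LLamaSharpTextEmbeddingGenerator(config);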
1 parent 0990be3 commit 4355809


7 files changed: +8 -10 lines changed


LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
     SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
     BatchSize = 512,
     UBatchSize = 512,
+    FlashAttention = true,
     UseMemorymap = true,
     PoolingType = LLamaPoolingType.Mean,
 };
@@ -67,6 +68,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
     SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
     BatchSize = 512,
     UBatchSize = 512,
+    FlashAttention = true,
     UseMemorymap = true,
     PoolingType = LLamaPoolingType.Mean,
 };

LLama.KernelMemory/LlamaSharpTextGenerator.cs

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@ public LlamaSharpTextGenerator(LLamaSharpConfig config)
     SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
     BatchSize = 512,
     UBatchSize = 512,
+    FlashAttention = true,
     UseMemorymap = true
 };
 _weights = LLamaWeights.LoadFromFile(@params);
@@ -65,6 +66,7 @@ public LlamaSharpTextGenerator(LLamaWeights weights, LLamaSharpConfig config, St
     SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
     BatchSize = 512,
     UBatchSize = 512,
+    FlashAttention = true,
     UseMemorymap = true
 };
 _executor = executor ?? new StatelessExecutor(_weights, @params);
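
Both wrappers hard-code FlashAttention = true rather than exposing it on LLamaSharpConfig, so callers on backends without flash attention support would need to build the executor directly. A short sketch using only types that appear in this diff (the FlashAttention property name is taken from the commit itself):

    // Bypass the KernelMemory wrapper to control the flag explicitly.
    var @params = new ModelParams("path/to/model.gguf")
    {
        BatchSize = 512,
        UBatchSize = 512,
        FlashAttention = false, // opt out where flash attention is unsupported
        UseMemorymap = true
    };
    using var weights = LLamaWeights.LoadFromFile(@params);
    var executor = new StatelessExecutor(weights, @params);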

LLama.Unittest/LLamaContextTests.cs

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@ public LLamaContextTests()
     ContextSize = 512,
     BatchSize = 8,
     UBatchSize = 8,
-    SeqMax = 1,
     VocabOnly = false,
     GpuLayerCount = Constants.CIGpuLayerCount,
 };

LLama.Unittest/LLamaRerankerTests.cs

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@ public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
 var @params = new ModelParams(Constants.RerankingModelPath)
 {
     ContextSize = 0,
-    SeqMax = 1,
     PoolingType = LLamaPoolingType.Rank,
     GpuLayerCount = Constants.CIGpuLayerCount,
 };

LLama.Unittest/SamplingTests.cs

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@ public SamplingTests(ITestOutputHelper testOutputHelper)
 _params = new ModelParams(Constants.GenerativeModelPath2) {
     ContextSize = 200,
     BatchSize = 200,
-    SeqMax = 4,
     GpuLayerCount = Constants.CIGpuLayerCount,
 };
 _model = LLamaWeights.LoadFromFile(_params);
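
SeqMax itself remains available on ModelParams; the tests now simply fall back to the library default (n_seq_max defaults to 1 in upstream llama.cpp). A short sketch of the explicit opt-in for callers that do decode several sequences in parallel:

    // Only needed when decoding multiple sequences in one context.
    var multiSeq = new ModelParams(Constants.GenerativeModelPath2)
    {
        ContextSize = 200,
        SeqMax = 4, // explicit, now that the tests no longer set it
    };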

LLama/Extensions/IContextParamsExtensions.cs

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
     result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
     result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
     result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified;
-
+
     result.defrag_threshold = @params.DefragThreshold ?? -1;
 
     result.cb_eval = IntPtr.Zero;

LLama/Native/NativeApi.cs

Lines changed: 3 additions & 6 deletions
@@ -175,15 +175,12 @@ public static void llama_empty_call()
 /// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param>
 /// <param name="length">The size of the allocated buffer</param>
 /// <returns>The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.</returns>
-public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
-    [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
+public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
 {
     return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length);
 
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl,
-        EntryPoint = "llama_chat_apply_template")]
-    static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
-        [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
+    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl,EntryPoint = "llama_chat_apply_template")]
+    static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
 }
 
 /// <summary>
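
The returns comment above implies the usual size-probe pattern for this API: apply the template into a best-guess buffer, and if the return value exceeds the buffer length, re-allocate and apply again. A hedged sketch of that pattern from an unsafe caller's side (tmpl, chat, n_msg and totalChars are assumed to be prepared elsewhere):

    // First attempt with the recommended starting size:
    // 2 * (total number of characters of all messages).
    var buf = new byte[2 * totalChars];
    int written;
    fixed (byte* bufPtr = buf)
        written = NativeApi.llama_chat_apply_template(tmpl, chat, n_msg, true, bufPtr, buf.Length);

    if (written > buf.Length)
    {
        // Buffer was too small: re-alloc to the reported size and re-apply.
        buf = new byte[written];
        fixed (byte* bufPtr = buf)
            written = NativeApi.llama_chat_apply_template(tmpl, chat, n_msg, true, bufPtr, buf.Length);
    }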
