
Commit ff6ea95

Fix Reranker and Sampling Test Failures
1 parent: 424a736

6 files changed (+18, -2 lines)

LLama.Unittest/LLamaRerankerTests.cs

Lines changed: 1 addition & 1 deletion
@@ -18,9 +18,9 @@ public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
         var @params = new ModelParams(Constants.RerankingModelPath)
         {
             ContextSize = 0,
+            SeqMax = 1,
             PoolingType = LLamaPoolingType.Rank,
             GpuLayerCount = Constants.CIGpuLayerCount,
-
         };
         using var weights = LLamaWeights.LoadFromFile(@params);
         _reranker = new LLamaReranker(weights, @params);
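
For reference, here is the fixed reranker setup as a standalone sketch. The identifiers are exactly those in the test; the explanatory comments are an editor's reading of the settings, not text from the commit:

    // Rank pooling yields one relevance score per input, so the reranking
    // context only ever needs a single sequence slot (SeqMax = 1).
    var @params = new ModelParams(Constants.RerankingModelPath)
    {
        ContextSize = 0,                      // 0 = use the model's own context length (llama.cpp convention)
        SeqMax = 1,                           // maximum number of parallel sequences in this context
        PoolingType = LLamaPoolingType.Rank,  // pooled output used for relevance ranking
        GpuLayerCount = Constants.CIGpuLayerCount,
    };
    using var weights = LLamaWeights.LoadFromFile(@params);
    var reranker = new LLamaReranker(weights, @params);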

LLama.Unittest/SamplingTests.cs

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ public SamplingTests(ITestOutputHelper testOutputHelper)
         _params = new ModelParams(Constants.GenerativeModelPath2) {
             ContextSize = 200,
             BatchSize = 200,
+            SeqMax = 4,
             GpuLayerCount = Constants.CIGpuLayerCount,
         };
         _model = LLamaWeights.LoadFromFile(_params);
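
The sampling fix is the same idea from the other direction: these tests decode several sequences against one context, so the context has to be created with enough sequence slots. A sketch of the fixed configuration (that the tests batch up to four sequences is inferred from SeqMax = 4, not stated in the commit):

    // Every sequence id submitted in a batch must fall below SeqMax;
    // with SeqMax = 4 a batch can interleave up to four sequences.
    var @params = new ModelParams(Constants.GenerativeModelPath2)
    {
        ContextSize = 200,
        BatchSize = 200,  // maximum tokens submitted per decode call
        SeqMax = 4,       // maximum number of distinct sequences per context
        GpuLayerCount = Constants.CIGpuLayerCount,
    };
    using var model = LLamaWeights.LoadFromFile(@params);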

LLama.Web/Common/ModelOptions.cs

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ public class ModelOptions
     public bool NoKqvOffload { get; set; }
 
     /// <inheritdoc />
-    public bool FlashAttention { get; set; }
+    public bool? FlashAttention { get; set; }
 
     /// <inheritdoc />
     public Encoding Encoding { get; set; } = Encoding.UTF8;

LLama/Abstractions/IContextParams.cs

Lines changed: 5 additions & 0 deletions
@@ -103,6 +103,11 @@ public interface IContextParams
     /// </summary>
     bool NoKqvOffload { get; }
 
+    /// <summary>
+    /// Whether to use flash attention
+    /// </summary>
+    bool? FlashAttention { get; }
+
     /// <summary>
     /// defragment the KV cache if holes/size &gt; defrag_threshold, Set to &lt;= 0 to disable (default)
     /// </summary>

LLama/Common/ModelParams.cs

Lines changed: 3 additions & 0 deletions
@@ -96,6 +96,9 @@ public record ModelParams
 
     /// <inheritdoc />
     public bool NoKqvOffload { get; set; }
+
+    /// <inheritdoc />
+    public bool? FlashAttention { get; set; }
 
     /// <inheritdoc />
     [Obsolete]

LLama/Extensions/IContextParamsExtensions.cs

Lines changed: 7 additions & 0 deletions
@@ -51,6 +51,13 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
         result.offload_kqv = !@params.NoKqvOffload;
         result.llama_pooling_type = @params.PoolingType;
         result.attention_type = @params.AttentionType;
+        result.llama_flash_attn_type = @params.FlashAttention switch
+        {
+            true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
+            false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
+            null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
+        };
+        result.kv_unified = true;
 
         result.n_threads = Threads(@params.Threads);
         result.n_threads_batch = Threads(@params.BatchThreads);
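
The nullable flag gives callers a tri-state switch over llama.cpp's flash-attention mode, with null deferring the choice to the backend. A minimal usage sketch (the model path is a placeholder):

    var @params = new ModelParams("<path to model.gguf>")
    {
        // true  maps to LLAMA_FLASH_ATTENTION_TYPE_ENABLED
        // false maps to LLAMA_FLASH_ATTENTION_TYPE_DISABLED
        FlashAttention = null,  // null maps to LLAMA_FLASH_ATTENTION_TYPE_AUTO: let llama.cpp decide
    };

The accompanying result.kv_unified = true opts every context into llama.cpp's unified KV-cache layout (one cache buffer shared across all sequences).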
