diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index e68cabe..60585bd 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -45,6 +45,11 @@ jobs:
- name: Test
run: dotnet test BitNet-b1.58-Sharp.slnx --configuration Release --no-build --no-restore --filter "Category!=SlowLane"
+ - name: Generate default model weights
+ run: |
+ mkdir -p "${{ github.workspace }}/src/BitNetSharp.Core/Data/Models"
+ dotnet run --framework net9.0 --project "${{ github.workspace }}/src/BitNetSharp.App/BitNetSharp.App.csproj" --configuration Release --no-build -- export --output="${{ github.workspace }}/src/BitNetSharp.Core/Data/Models/bitnet-b1.58-default.gguf"
+
- name: Pack BitNetSharp.Core
run: dotnet pack "${{ github.workspace }}/src/BitNetSharp.Core/BitNetSharp.Core.csproj" --configuration Release --no-build --no-restore -p:PackageVersion=${{ steps.gitversion.outputs.semVer }} --output "${{ github.workspace }}/artifacts/packages/core"
diff --git a/.gitignore b/.gitignore
index 855a6ef..e8b65a7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -420,3 +420,6 @@ FodyWeavers.xsd
AGENTS-README-FIRST.yaml
.mcpServer/
+
+# Generated model weights (produced during CI builds)
+*.gguf
diff --git a/.version b/.version
index 6339605..a918a2a 100644
--- a/.version
+++ b/.version
@@ -1 +1 @@
-0.1.0-78
+0.6.0
diff --git a/Directory.Build.props b/Directory.Build.props
index 2addca1..2d12427 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -1,5 +1,5 @@
- 0.1.0
+ 0.6.0
diff --git a/GitVersion.yml b/GitVersion.yml
index 3f5c2ea..2a18fb4 100644
--- a/GitVersion.yml
+++ b/GitVersion.yml
@@ -1,4 +1,4 @@
-next-version: 0.1.0
+next-version: 0.6.0
mode: ContinuousDelivery
tag-prefix: '[vV]?'
strategies:
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 8ddfd10..ab349d7 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -13,14 +13,21 @@ pr:
- "*"
variables:
- BuildConfiguration: Release
- SolutionFile: BitNet-b1.58-Sharp.slnx
- CoreProject: src/BitNetSharp.Core/BitNetSharp.Core.csproj
- AppProject: src/BitNetSharp.App/BitNetSharp.App.csproj
- CoreArtifactName: BitNetSharp.Core-nuget
- ToolArtifactName: BitNetSharp.App-dotnet-tool
- NuGetFeedUrl: ""
- NuGetApiKey: ""
+ - group: McpServer
+ - name: BuildConfiguration
+ value: Release
+ - name: SolutionFile
+ value: BitNet-b1.58-Sharp.slnx
+ - name: CoreProject
+ value: src/BitNetSharp.Core/BitNetSharp.Core.csproj
+ - name: AppProject
+ value: src/BitNetSharp.App/BitNetSharp.App.csproj
+ - name: CoreArtifactName
+ value: BitNetSharp.Core-nuget
+ - name: ToolArtifactName
+ value: BitNetSharp.App-dotnet-tool
+ - name: NuGetFeedUrl
+ value: "https://nuget.pkg.github.com/sharpninja/index.json"
stages:
- stage: build
@@ -70,6 +77,12 @@ stages:
- script: dotnet test "$(SolutionFile)" --configuration "$(BuildConfiguration)" --no-build --no-restore --filter "Category!=SlowLane"
displayName: Test
+ - pwsh: |
+ $modelsDirectory = "$(Build.SourcesDirectory)/src/BitNetSharp.Core/Data/Models"
+ New-Item -ItemType Directory -Force -Path $modelsDirectory | Out-Null
+ dotnet run --framework net9.0 --project "$(Build.SourcesDirectory)/$(AppProject)" --configuration "$(BuildConfiguration)" --no-build -- export --output="$modelsDirectory/bitnet-b1.58-default.gguf"
+ displayName: Generate default model weights
+
- pwsh: |
$outputDirectory = "$(Build.ArtifactStagingDirectory)/packages/core"
New-Item -ItemType Directory -Force -Path $outputDirectory | Out-Null
@@ -155,16 +168,9 @@ stages:
foreach ($package in $packages)
{
- if ([string]::IsNullOrWhiteSpace($env:NUGET_API_KEY))
- {
- dotnet nuget push $package.FullName --source $env:NUGET_SOURCE --api-key AzureArtifacts --skip-duplicate
- }
- else
- {
- dotnet nuget push $package.FullName --source $env:NUGET_SOURCE --api-key $env:NUGET_API_KEY --skip-duplicate
- }
+ dotnet nuget push $package.FullName --source $env:NUGET_SOURCE --api-key $env:NUGET_API_KEY --skip-duplicate
}
displayName: Publish packages to NuGet feed
env:
NUGET_SOURCE: $(NuGetFeedUrl)
- NUGET_API_KEY: $(NuGetApiKey)
+ NUGET_API_KEY: $(GH_TOKEN)
diff --git a/nuget.config b/nuget.config
new file mode 100644
index 0000000..b2dfa8d
--- /dev/null
+++ b/nuget.config
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/BitNetSharp.Core/BitNetOptions.cs b/src/BitNetSharp.Core/BitNetOptions.cs
index ae3322f..5008931 100644
--- a/src/BitNetSharp.Core/BitNetOptions.cs
+++ b/src/BitNetSharp.Core/BitNetOptions.cs
@@ -7,4 +7,5 @@ public sealed record BitNetOptions(
string PrimaryLanguage = "en-US",
bool EnableChainBuckets = false,
bool EnableSequenceCompression = false,
- double ChainBucketAcceptanceThreshold = 0.85d);
+ double ChainBucketAcceptanceThreshold = 0.85d,
+ bool EnableRecallHeatMap = true);
diff --git a/src/BitNetSharp.Core/BitNetPaperCheckpoint.cs b/src/BitNetSharp.Core/BitNetPaperCheckpoint.cs
index 783c5bd..002cf89 100644
--- a/src/BitNetSharp.Core/BitNetPaperCheckpoint.cs
+++ b/src/BitNetSharp.Core/BitNetPaperCheckpoint.cs
@@ -25,12 +25,14 @@ internal sealed record BitNetPaperCheckpointDocument(
string PrimaryLanguage,
bool EnableChainBuckets,
bool EnableSequenceCompression,
- double ChainBucketAcceptanceThreshold);
+ double ChainBucketAcceptanceThreshold,
+ bool EnableRecallHeatMap = true);
public static class BitNetPaperCheckpoint
{
private const string FormatName = "bitnet-b1.58-sharp.repository-checkpoint.v1";
private const string BucketSidecarFileName = "chain-buckets.bin";
+ private const string HeatMapSidecarFileName = "recall-heatmap.bin";
public static void Save(BitNetPaperModel model, string path)
{
@@ -59,9 +61,11 @@ public static void Save(BitNetPaperModel model, string path)
snapshot.PrimaryLanguage,
snapshot.EnableChainBuckets,
snapshot.EnableSequenceCompression,
- snapshot.ChainBucketAcceptanceThreshold);
+ snapshot.ChainBucketAcceptanceThreshold,
+ model.Options.EnableRecallHeatMap);
File.WriteAllText(path, JsonSerializer.Serialize(document, new JsonSerializerOptions { WriteIndented = true }));
SaveBucketSidecar(model.BucketTable, GetBucketSidecarPath(path));
+ SaveHeatMapSidecar(model.RecallHeatMap, GetHeatMapSidecarPath(path));
}
public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = VerbosityLevel.Normal)
@@ -86,7 +90,8 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb
document.PrimaryLanguage,
document.EnableChainBuckets,
document.EnableSequenceCompression,
- acceptanceThreshold),
+ acceptanceThreshold,
+ document.EnableRecallHeatMap),
document.Config,
document.BootstrapSeed);
var baselineSnapshot = BitNetPaperModelSnapshot.Capture(baselineModel);
@@ -119,6 +124,12 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb
model.LoadBucketTable(LoadBucketSidecar(bucketSidecarPath));
}
+ var heatMapSidecarPath = GetHeatMapSidecarPath(path);
+ if (model.RecallHeatMap is not null && File.Exists(heatMapSidecarPath))
+ {
+ model.RecallHeatMap.MergeFrom(BucketRecallHeatMapSerializer.Load(heatMapSidecarPath));
+ }
+
return model;
}
@@ -175,6 +186,29 @@ private static void SaveBucketSidecar(ChainBucketTable? bucketTable, string path
private static ChainBucketTable LoadBucketSidecar(string path) => ChainBucketTableBinarySerializer.Load(path);
+ private static string GetHeatMapSidecarPath(string checkpointPath)
+ {
+ var directory = Path.GetDirectoryName(checkpointPath);
+ return string.IsNullOrWhiteSpace(directory)
+ ? HeatMapSidecarFileName
+ : Path.Combine(directory, HeatMapSidecarFileName);
+ }
+
+ private static void SaveHeatMapSidecar(BucketRecallHeatMap? heatMap, string path)
+ {
+ if (heatMap is null)
+ {
+ if (File.Exists(path))
+ {
+ File.Delete(path);
+ }
+
+ return;
+ }
+
+ BucketRecallHeatMapSerializer.Save(heatMap, path);
+ }
+
private static float[][] ToJagged(float[,] matrix)
{
var result = new float[matrix.GetLength(0)][];
diff --git a/src/BitNetSharp.Core/BitNetPaperGguf.cs b/src/BitNetSharp.Core/BitNetPaperGguf.cs
index 916868e..2b0048e 100644
--- a/src/BitNetSharp.Core/BitNetPaperGguf.cs
+++ b/src/BitNetSharp.Core/BitNetPaperGguf.cs
@@ -25,6 +25,7 @@ public static void Save(BitNetPaperModel model, string path)
GgufWriter.Write(path, CreateMetadata(snapshot), CreateTensors(snapshot));
SaveBucketSidecar(model.BucketTable, GetBucketSidecarPath(path));
+ SaveHeatMapSidecar(model.RecallHeatMap, GetHeatMapSidecarPath(path));
}
public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = VerbosityLevel.Normal)
@@ -78,7 +79,8 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb
transformerProjectionWeights,
normScales,
ReadMatrix(tensors[OutputTensorName], config.VocabSize, config.Dimension),
- DeserializeMemorizedResponses(GetRequiredString(document.Metadata, MemorizedResponsesMetadataKey)));
+ DeserializeMemorizedResponses(GetRequiredString(document.Metadata, MemorizedResponsesMetadataKey)),
+ ReadOptionalBool(document.Metadata, "bitnetsharp.enable_recall_heat_map", defaultValue: true));
var model = snapshot.Restore(verbosity);
var bucketSidecarPath = GetBucketSidecarPath(path);
@@ -87,6 +89,12 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb
model.LoadBucketTable(ChainBucketTableBinarySerializer.Load(bucketSidecarPath));
}
+ var heatMapSidecarPath = GetHeatMapSidecarPath(path);
+ if (model.RecallHeatMap is not null && File.Exists(heatMapSidecarPath))
+ {
+ model.RecallHeatMap.MergeFrom(BucketRecallHeatMapSerializer.Load(heatMapSidecarPath));
+ }
+
return model;
}
@@ -106,6 +114,7 @@ private static Dictionary CreateMetadata(BitNetPaperModelSnapsho
["bitnetsharp.primary_language"] = snapshot.PrimaryLanguage,
["bitnetsharp.enable_chain_buckets"] = snapshot.EnableChainBuckets,
["bitnetsharp.enable_sequence_compression"] = snapshot.EnableSequenceCompression,
+ ["bitnetsharp.enable_recall_heat_map"] = snapshot.EnableRecallHeatMap,
["bitnetsharp.chain_bucket_acceptance_threshold"] = snapshot.ChainBucketAcceptanceThreshold,
["bitnetsharp.config.vocab_size"] = snapshot.Config.VocabSize,
["bitnetsharp.config.dimension"] = snapshot.Config.Dimension,
@@ -310,6 +319,31 @@ private static float[] FlattenMatrix(float[,] matrix)
private static string GetFeedForwardProjectionTensorName(int layer, string suffix) => $"blk.{layer}.ffn_{suffix}.weight";
+ private static string GetHeatMapSidecarPath(string ggufPath)
+ {
+ var directory = Path.GetDirectoryName(ggufPath);
+ var baseName = Path.GetFileNameWithoutExtension(ggufPath);
+ var fileName = $"{baseName}.recall-heatmap.bin";
+ return string.IsNullOrWhiteSpace(directory)
+ ? fileName
+ : Path.Combine(directory, fileName);
+ }
+
+ private static void SaveHeatMapSidecar(BucketRecallHeatMap? heatMap, string path)
+ {
+ if (heatMap is null)
+ {
+ if (File.Exists(path))
+ {
+ File.Delete(path);
+ }
+
+ return;
+ }
+
+ BucketRecallHeatMapSerializer.Save(heatMap, path);
+ }
+
private static string GetBucketSidecarPath(string ggufPath)
{
var directory = Path.GetDirectoryName(ggufPath);
@@ -354,6 +388,16 @@ private static string GetRequiredString(IReadOnlyDictionary meta
return text;
}
+ private static bool ReadOptionalBool(IReadOnlyDictionary metadata, string key, bool defaultValue)
+ {
+ if (metadata.TryGetValue(key, out var value) && value is bool boolean)
+ {
+ return boolean;
+ }
+
+ return defaultValue;
+ }
+
private static bool GetRequiredBool(IReadOnlyDictionary metadata, string key)
{
if (!metadata.TryGetValue(key, out var value) || value is not bool boolean)
diff --git a/src/BitNetSharp.Core/BitNetPaperModel.cs b/src/BitNetSharp.Core/BitNetPaperModel.cs
index 151d19d..4aff55f 100644
--- a/src/BitNetSharp.Core/BitNetPaperModel.cs
+++ b/src/BitNetSharp.Core/BitNetPaperModel.cs
@@ -24,6 +24,7 @@ public sealed class BitNetPaperModel
private readonly string[] _idToToken;
private readonly BitNetTokenizer _tokenizer;
private readonly object _gate = new();
+ private BucketRecallHeatMap? _recallHeatMap;
public BitNetPaperModel(IEnumerable trainingExamples, VerbosityLevel verbosity = VerbosityLevel.Normal, BitNetConfig? config = null, int seed = 42)
: this(
@@ -102,12 +103,21 @@ .. options.Vocabulary
///
public ChainBucketTable? BucketTable { get; private set; }
+    /// <summary>
+    /// Optional recall heat map that tracks per-token and per-chain accept/attempt counts
+    /// during speculative decoding. Populated when a bucket table is loaded and
+    /// <see cref="BitNetOptions.EnableRecallHeatMap"/> is set.
+    /// </summary>
+ public BucketRecallHeatMap? RecallHeatMap => _recallHeatMap;
+
public string ModelId => "bitnet-b1.58-sharp";
public BitNetTokenizer Tokenizer => _tokenizer;
public long EstimateResidentParameterBytes() => Transformer.EstimateResidentParameterBytes();
+ public string GetTokenString(int tokenId) => _idToToken[tokenId];
+
///
/// Mines chain buckets from the provided training examples using the model's tokenizer,
/// builds a , attaches it to this model, and returns it.
@@ -141,6 +151,11 @@ public void LoadBucketTable(ChainBucketTable table)
{
ArgumentNullException.ThrowIfNull(table);
BucketTable = table;
+
+ if (Options.EnableRecallHeatMap)
+ {
+ _recallHeatMap = new BucketRecallHeatMap(Config.VocabSize);
+ }
}
public static BitNetPaperModel CreateDefault(
@@ -194,6 +209,8 @@ public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = n
{
lock (_gate)
{
+ _recallHeatMap?.ResetGenerationState();
+
var diagnostics = new List();
var contextTokenIds = TokenizeToIds(prompt).ToList();
var generatedTokenIds = new List();
@@ -293,6 +310,7 @@ public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = n
if (matchedPrefixLen > 0)
{
attemptedChains++;
+ _recallHeatMap?.RecordChainAttempt(chain.ChainId, chain.TokenIds, matchedPrefixLen);
var acceptedTokensForChain = 0;
for (var ci = matchedPrefixLen; ci < chain.TokenIds.Length && step < maxGeneratedTokens - 1; ci++)
{
@@ -322,6 +340,7 @@ public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = n
step++;
acceptedTokensForChain++;
acceptedChainTokens++;
+ _recallHeatMap?.RecordTokenAccepted(chain.ChainId, speculativeId);
if (Options.Verbosity == VerbosityLevel.Verbose)
{
@@ -333,6 +352,7 @@ public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = n
if (acceptedTokensForChain > 0)
{
acceptedChains++;
+ _recallHeatMap?.RecordChainAccepted(chain.ChainId);
}
}
}
diff --git a/src/BitNetSharp.Core/BitNetSharp.Core.csproj b/src/BitNetSharp.Core/BitNetSharp.Core.csproj
index d27ba3b..6da77b1 100644
--- a/src/BitNetSharp.Core/BitNetSharp.Core.csproj
+++ b/src/BitNetSharp.Core/BitNetSharp.Core.csproj
@@ -17,4 +17,12 @@
+
+
+
+
diff --git a/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMap.cs b/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMap.cs
new file mode 100644
index 0000000..43ee462
--- /dev/null
+++ b/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMap.cs
@@ -0,0 +1,438 @@
+namespace BitNetSharp.Core.Bucketing;
+
+public sealed record HeatMapCounters(
+ int VocabSize,
+ long[] AttemptCounts,
+ long[] AcceptCounts,
+ long[] ChainAttemptCounts,
+ long[] ChainAcceptCounts,
+ long[,] ChainTransitions);
+
+public sealed record TokenRecallEntry(int TokenId, long AttemptCount, long AcceptCount, double RecallRate);
+
+public sealed record BucketRecallRanking(
+ byte ChainId,
+ double AggregateRecallRate,
+ long TotalAcceptCount,
+ long TotalAttemptCount,
+ bool OnHotPath);
+
+public sealed record ChainTransitionEntry(byte ChainId, long TransitionCount);
+
+public sealed record ChainHotPath(byte[] ChainSequence, long MinTransitionCount);
+
+public sealed class BucketRecallHeatMap
+{
+ private readonly long[] _attemptCounts;
+ private readonly long[] _acceptCounts;
+ private readonly long[] _chainAttemptCounts;
+ private readonly long[] _chainAcceptCounts;
+ private readonly long[,] _chainTransitions;
+ private byte? _lastAcceptedChainId;
+
+ public int VocabSize { get; }
+
+ public BucketRecallHeatMap(int vocabSize)
+ {
+ ArgumentOutOfRangeException.ThrowIfNegativeOrZero(vocabSize);
+
+ VocabSize = vocabSize;
+ _attemptCounts = new long[vocabSize];
+ _acceptCounts = new long[vocabSize];
+ _chainAttemptCounts = new long[ChainBucketTable.MaxBuckets];
+ _chainAcceptCounts = new long[ChainBucketTable.MaxBuckets];
+ _chainTransitions = new long[ChainBucketTable.MaxBuckets, ChainBucketTable.MaxBuckets];
+ }
+
+ public void RecordChainAttempt(byte chainId, int[] tokenIds, int speculativeStartIndex)
+ {
+ ArgumentNullException.ThrowIfNull(tokenIds);
+
+ _chainAttemptCounts[chainId]++;
+ for (var i = speculativeStartIndex; i < tokenIds.Length; i++)
+ {
+ var tokenId = tokenIds[i];
+ if ((uint)tokenId < (uint)VocabSize)
+ {
+ _attemptCounts[tokenId]++;
+ }
+ }
+ }
+
+    public void RecordTokenAccepted(byte chainId, int tokenId) // NOTE(review): chainId is not read here; only the per-token accept counter changes.
+    {
+        if ((uint)tokenId < (uint)VocabSize) // one unsigned compare skips negative and out-of-range token ids
+        {
+            _acceptCounts[tokenId]++;
+        }
+    }
+
+    public void RecordChainAccepted(byte chainId) // Counts an accepted chain and the transition from the previously accepted chain, if any.
+    {
+        _chainAcceptCounts[chainId]++;
+        if (_lastAcceptedChainId.HasValue) // null until a chain is accepted after construction, Reset(), or ResetGenerationState()
+        {
+            _chainTransitions[_lastAcceptedChainId.Value, chainId]++; // indexed as [from, to]
+        }
+
+        _lastAcceptedChainId = chainId; // becomes the "from" side of the next recorded transition
+    }
+
+ public void ResetGenerationState()
+ {
+ _lastAcceptedChainId = null;
+ }
+
+ public long GetAttemptCount(int tokenId) =>
+ (uint)tokenId < (uint)VocabSize ? _attemptCounts[tokenId] : 0L;
+
+ public long GetAcceptCount(int tokenId) =>
+ (uint)tokenId < (uint)VocabSize ? _acceptCounts[tokenId] : 0L;
+
+ public double GetTokenRecallRate(int tokenId)
+ {
+ var attempts = GetAttemptCount(tokenId);
+ return attempts == 0L ? 0d : (double)GetAcceptCount(tokenId) / attempts;
+ }
+
+ public long GetChainAttemptCount(byte chainId) => _chainAttemptCounts[chainId];
+
+ public long GetChainAcceptCount(byte chainId) => _chainAcceptCounts[chainId];
+
+ public double GetChainRecallRate(byte chainId)
+ {
+ var attempts = _chainAttemptCounts[chainId];
+ return attempts == 0L ? 0d : (double)_chainAcceptCounts[chainId] / attempts;
+ }
+
+ public long GetTransitionCount(byte fromChainId, byte toChainId) =>
+ _chainTransitions[fromChainId, toChainId];
+
+    /// <summary>
+    /// Lists the chains most often recorded as the previously accepted chain before <paramref name="chainId"/>.
+    /// </summary>
+    /// <param name="chainId">Chain whose recorded predecessors are wanted.</param>
+    /// <param name="maxResults">Maximum number of entries to return.</param>
+    /// <returns>Predecessor chains with a non-zero transition count, ordered by descending count.</returns>
+    public IReadOnlyList<ChainTransitionEntry> GetIncomingChains(byte chainId, int maxResults = 10)
+    {
+        // Scan the [from, chainId] column once, keep only the edges that were actually observed.
+        return Enumerable.Range(0, ChainBucketTable.MaxBuckets)
+            .Select(source => (Id: (byte)source, Count: _chainTransitions[source, chainId]))
+            .Where(static edge => edge.Count > 0)
+            .Select(static edge => new ChainTransitionEntry(edge.Id, edge.Count))
+            .OrderByDescending(static entry => entry.TransitionCount) // stable: ties stay in ascending chain-id order
+            .Take(maxResults)
+            .ToArray();
+    }
+
+    /// <summary>
+    /// Lists the chains most often recorded as accepted right after <paramref name="chainId"/> was the previously accepted chain.
+    /// </summary>
+    /// <param name="chainId">Chain whose recorded successors are wanted.</param>
+    /// <param name="maxResults">Maximum number of entries to return.</param>
+    /// <returns>Successor chains with a non-zero transition count, ordered by descending count.</returns>
+    public IReadOnlyList<ChainTransitionEntry> GetOutgoingChains(byte chainId, int maxResults = 10)
+    {
+        // Scan the [chainId, to] row once, keep only the edges that were actually observed.
+        return Enumerable.Range(0, ChainBucketTable.MaxBuckets)
+            .Select(target => (Id: (byte)target, Count: _chainTransitions[chainId, target]))
+            .Where(static edge => edge.Count > 0)
+            .Select(static edge => new ChainTransitionEntry(edge.Id, edge.Count))
+            .OrderByDescending(static entry => entry.TransitionCount) // stable: ties stay in ascending chain-id order
+            .Take(maxResults)
+            .ToArray();
+    }
+
+ public IReadOnlyList GetHotPaths(int maxDepth = 5, int maxResults = 10, long minTransitions = 2)
+ {
+ var paths = new List();
+ var visited = new HashSet();
+
+ for (var start = 0; start < ChainBucketTable.MaxBuckets; start++)
+ {
+ if (_chainAcceptCounts[start] == 0)
+ {
+ continue;
+ }
+
+ visited.Clear();
+ var sequence = new List { (byte)start };
+ visited.Add((byte)start);
+ var bottleneck = long.MaxValue;
+
+ var current = (byte)start;
+ for (var depth = 1; depth < maxDepth; depth++)
+ {
+ var bestNext = -1;
+ var bestCount = 0L;
+ for (var next = 0; next < ChainBucketTable.MaxBuckets; next++)
+ {
+ var count = _chainTransitions[current, next];
+ if (count >= minTransitions && count > bestCount && !visited.Contains((byte)next))
+ {
+ bestNext = next;
+ bestCount = count;
+ }
+ }
+
+ if (bestNext < 0)
+ {
+ break;
+ }
+
+ sequence.Add((byte)bestNext);
+ visited.Add((byte)bestNext);
+ bottleneck = Math.Min(bottleneck, bestCount);
+ current = (byte)bestNext;
+ }
+
+ if (sequence.Count >= 2)
+ {
+ paths.Add(new ChainHotPath(sequence.ToArray(), bottleneck));
+ }
+ }
+
+        // Remove sub-paths: if path A appears as a contiguous run anywhere inside a longer kept path B, keep only B.
+ var deduplicated = new List();
+ var sortedByLength = paths.OrderByDescending(static p => p.ChainSequence.Length).ToArray();
+ foreach (var path in sortedByLength)
+ {
+ var isSubPath = false;
+ foreach (var longer in deduplicated)
+ {
+ if (IsSubsequence(path.ChainSequence, longer.ChainSequence))
+ {
+ isSubPath = true;
+ break;
+ }
+ }
+
+ if (!isSubPath)
+ {
+ deduplicated.Add(path);
+ }
+ }
+
+ return deduplicated
+ .OrderByDescending(static p => p.MinTransitionCount)
+ .Take(maxResults)
+ .ToArray();
+ }
+
+ public IReadOnlyList GetTopTokensByAcceptCount(int maxResults = 20)
+ {
+ var entries = new List();
+ for (var i = 0; i < VocabSize; i++)
+ {
+ if (_acceptCounts[i] > 0)
+ {
+ entries.Add(new TokenRecallEntry(i, _attemptCounts[i], _acceptCounts[i], GetTokenRecallRate(i)));
+ }
+ }
+
+ return entries
+ .OrderByDescending(static e => e.AcceptCount)
+ .Take(maxResults)
+ .ToArray();
+ }
+
+ public IReadOnlyList GetTopTokensByRecallRate(int maxResults = 20, int minAttempts = 1)
+ {
+ var entries = new List();
+ for (var i = 0; i < VocabSize; i++)
+ {
+ if (_attemptCounts[i] >= minAttempts)
+ {
+ entries.Add(new TokenRecallEntry(i, _attemptCounts[i], _acceptCounts[i], GetTokenRecallRate(i)));
+ }
+ }
+
+ return entries
+ .OrderByDescending(static e => e.RecallRate)
+ .Take(maxResults)
+ .ToArray();
+ }
+
+ public IReadOnlyList RankBucketsForCompaction(ChainBucketTable table)
+ {
+ ArgumentNullException.ThrowIfNull(table);
+
+ var hotPathChainIds = GetHotPathChainIds();
+ var rankings = new List();
+
+ foreach (var bucket in table.Buckets)
+ {
+ var attempts = _chainAttemptCounts[bucket.ChainId];
+ var accepts = _chainAcceptCounts[bucket.ChainId];
+ var recallRate = attempts == 0L ? 0d : (double)accepts / attempts;
+ var onHotPath = hotPathChainIds.Contains(bucket.ChainId);
+
+ rankings.Add(new BucketRecallRanking(
+ bucket.ChainId,
+ recallRate,
+ accepts,
+ attempts,
+ onHotPath));
+ }
+
+ // Sort worst-first: non-hot-path chains with low recall come first (pruning candidates).
+ return rankings
+ .OrderBy(static r => r.OnHotPath)
+ .ThenBy(static r => r.AggregateRecallRate)
+ .ThenBy(static r => r.TotalAcceptCount)
+ .ToArray();
+ }
+
+ public IReadOnlySet IdentifyLowValueBuckets(ChainBucketTable table, double threshold = 0.5, int minAttempts = 2)
+ {
+ ArgumentNullException.ThrowIfNull(table);
+
+ var hotPathChainIds = GetHotPathChainIds();
+ var lowValue = new HashSet();
+
+ foreach (var bucket in table.Buckets)
+ {
+ if (hotPathChainIds.Contains(bucket.ChainId))
+ {
+ continue;
+ }
+
+ var attempts = _chainAttemptCounts[bucket.ChainId];
+ if (attempts < minAttempts)
+ {
+ lowValue.Add(bucket.ChainId);
+ continue;
+ }
+
+ var recallRate = (double)_chainAcceptCounts[bucket.ChainId] / attempts;
+ if (recallRate < threshold)
+ {
+ lowValue.Add(bucket.ChainId);
+ }
+ }
+
+ return lowValue;
+ }
+
+ public void Reset()
+ {
+ Array.Clear(_attemptCounts);
+ Array.Clear(_acceptCounts);
+ Array.Clear(_chainAttemptCounts);
+ Array.Clear(_chainAcceptCounts);
+ Array.Clear(_chainTransitions);
+ _lastAcceptedChainId = null;
+ }
+
+ public void MergeFrom(BucketRecallHeatMap other)
+ {
+ ArgumentNullException.ThrowIfNull(other);
+ if (other.VocabSize != VocabSize)
+ {
+ throw new ArgumentException(
+ $"Cannot merge heat maps with different vocab sizes ({VocabSize} vs {other.VocabSize}).",
+ nameof(other));
+ }
+
+ for (var i = 0; i < VocabSize; i++)
+ {
+ _attemptCounts[i] += other._attemptCounts[i];
+ _acceptCounts[i] += other._acceptCounts[i];
+ }
+
+ for (var i = 0; i < ChainBucketTable.MaxBuckets; i++)
+ {
+ _chainAttemptCounts[i] += other._chainAttemptCounts[i];
+ _chainAcceptCounts[i] += other._chainAcceptCounts[i];
+ for (var j = 0; j < ChainBucketTable.MaxBuckets; j++)
+ {
+ _chainTransitions[i, j] += other._chainTransitions[i, j];
+ }
+ }
+ }
+
+ public HeatMapCounters ExportCounters()
+ {
+ var attemptCounts = (long[])_attemptCounts.Clone();
+ var acceptCounts = (long[])_acceptCounts.Clone();
+ var chainAttemptCounts = (long[])_chainAttemptCounts.Clone();
+ var chainAcceptCounts = (long[])_chainAcceptCounts.Clone();
+ var chainTransitions = (long[,])_chainTransitions.Clone();
+ return new HeatMapCounters(VocabSize, attemptCounts, acceptCounts, chainAttemptCounts, chainAcceptCounts, chainTransitions);
+ }
+
+ public static BucketRecallHeatMap FromCounters(HeatMapCounters counters)
+ {
+ ArgumentNullException.ThrowIfNull(counters);
+
+ if (counters.AttemptCounts.Length != counters.VocabSize
+ || counters.AcceptCounts.Length != counters.VocabSize)
+ {
+ throw new ArgumentException("Token counter arrays must match VocabSize.", nameof(counters));
+ }
+
+ if (counters.ChainAttemptCounts.Length != ChainBucketTable.MaxBuckets
+ || counters.ChainAcceptCounts.Length != ChainBucketTable.MaxBuckets)
+ {
+ throw new ArgumentException($"Chain counter arrays must have length {ChainBucketTable.MaxBuckets}.", nameof(counters));
+ }
+
+ if (counters.ChainTransitions.GetLength(0) != ChainBucketTable.MaxBuckets
+ || counters.ChainTransitions.GetLength(1) != ChainBucketTable.MaxBuckets)
+ {
+ throw new ArgumentException($"Chain transitions matrix must be {ChainBucketTable.MaxBuckets}x{ChainBucketTable.MaxBuckets}.", nameof(counters));
+ }
+
+ var heatMap = new BucketRecallHeatMap(counters.VocabSize);
+ Array.Copy(counters.AttemptCounts, heatMap._attemptCounts, counters.VocabSize);
+ Array.Copy(counters.AcceptCounts, heatMap._acceptCounts, counters.VocabSize);
+ Array.Copy(counters.ChainAttemptCounts, heatMap._chainAttemptCounts, ChainBucketTable.MaxBuckets);
+ Array.Copy(counters.ChainAcceptCounts, heatMap._chainAcceptCounts, ChainBucketTable.MaxBuckets);
+ Array.Copy(counters.ChainTransitions, heatMap._chainTransitions, ChainBucketTable.MaxBuckets * ChainBucketTable.MaxBuckets);
+ return heatMap;
+ }
+
+ private HashSet GetHotPathChainIds()
+ {
+ var hotPaths = GetHotPaths();
+ var chainIds = new HashSet();
+ foreach (var path in hotPaths)
+ {
+ foreach (var chainId in path.ChainSequence)
+ {
+ chainIds.Add(chainId);
+ }
+ }
+
+ return chainIds;
+ }
+
+ private static bool IsSubsequence(byte[] candidate, byte[] longer)
+ {
+ if (candidate.Length >= longer.Length)
+ {
+ return false;
+ }
+
+ for (var start = 0; start <= longer.Length - candidate.Length; start++)
+ {
+ var match = true;
+ for (var i = 0; i < candidate.Length; i++)
+ {
+ if (longer[start + i] != candidate[i])
+ {
+ match = false;
+ break;
+ }
+ }
+
+ if (match)
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+}
diff --git a/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMapSerializer.cs b/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMapSerializer.cs
new file mode 100644
index 0000000..9d97468
--- /dev/null
+++ b/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMapSerializer.cs
@@ -0,0 +1,192 @@
+using System.Buffers.Binary;
+using System.Text;
+
+namespace BitNetSharp.Core.Bucketing;
+
+/// <summary>
+/// Provides a compact binary serializer for <see cref="BucketRecallHeatMap"/> sidecar persistence.
+/// The wire format is little-endian and follows the repository's recall-heatmap.bin v1 layout.
+/// </summary>
+public static class BucketRecallHeatMapSerializer
+{
+ private static readonly byte[] MagicHeader = Encoding.ASCII.GetBytes("BRHM");
+ private const ushort FormatVersion = 1;
+ private const int HeaderLength = 12;
+ private const int FooterLength = 4;
+
+    /// <summary>Saves a recall heat map to a binary file.</summary>
+ public static void Save(BucketRecallHeatMap heatMap, string path)
+ {
+ ArgumentNullException.ThrowIfNull(heatMap);
+ ArgumentException.ThrowIfNullOrWhiteSpace(path);
+
+ var directory = Path.GetDirectoryName(path);
+ if (!string.IsNullOrWhiteSpace(directory))
+ {
+ Directory.CreateDirectory(directory);
+ }
+
+ using var stream = File.Create(path);
+ Serialize(heatMap, stream);
+ }
+
+    /// <summary>Loads a recall heat map from a binary file.</summary>
+ public static BucketRecallHeatMap Load(string path)
+ {
+ ArgumentException.ThrowIfNullOrWhiteSpace(path);
+
+ using var stream = File.OpenRead(path);
+ return Deserialize(stream);
+ }
+
+    /// <summary>Writes a recall heat map to a stream using the binary sidecar format.</summary>
+ public static void Serialize(BucketRecallHeatMap heatMap, Stream destination)
+ {
+ ArgumentNullException.ThrowIfNull(heatMap);
+ ArgumentNullException.ThrowIfNull(destination);
+
+ var counters = heatMap.ExportCounters();
+
+ using var payload = new MemoryStream();
+ using (var writer = new BinaryWriter(payload, Encoding.ASCII, leaveOpen: true))
+ {
+ writer.Write(MagicHeader);
+ writer.Write(FormatVersion);
+ writer.Write((ushort)0);
+ writer.Write(counters.VocabSize);
+
+ for (var i = 0; i < counters.VocabSize; i++)
+ {
+ writer.Write(counters.AttemptCounts[i]);
+ }
+
+ for (var i = 0; i < counters.VocabSize; i++)
+ {
+ writer.Write(counters.AcceptCounts[i]);
+ }
+
+ for (var i = 0; i < ChainBucketTable.MaxBuckets; i++)
+ {
+ writer.Write(counters.ChainAttemptCounts[i]);
+ }
+
+ for (var i = 0; i < ChainBucketTable.MaxBuckets; i++)
+ {
+ writer.Write(counters.ChainAcceptCounts[i]);
+ }
+
+ for (var row = 0; row < ChainBucketTable.MaxBuckets; row++)
+ {
+ for (var col = 0; col < ChainBucketTable.MaxBuckets; col++)
+ {
+ writer.Write(counters.ChainTransitions[row, col]);
+ }
+ }
+ }
+
+ var buffer = payload.ToArray();
+ var checksum = ComputeCrc32(buffer);
+ destination.Write(buffer, 0, buffer.Length);
+
+ Span checksumBuffer = stackalloc byte[FooterLength];
+ BinaryPrimitives.WriteUInt32LittleEndian(checksumBuffer, checksum);
+ destination.Write(checksumBuffer);
+ }
+
+    /// <summary>Reads a recall heat map from a stream using the binary sidecar format.</summary>
+    /// <exception cref="InvalidDataException">The checksum, header, version, vocab size, or payload length is invalid.</exception>
+    public static BucketRecallHeatMap Deserialize(Stream source)
+    {
+        ArgumentNullException.ThrowIfNull(source);
+
+        using var payload = new MemoryStream();
+        source.CopyTo(payload);
+        var bytes = payload.ToArray();
+
+        if (bytes.Length < HeaderLength + FooterLength)
+        {
+            throw new InvalidDataException("Recall heat map binary payload is too short.");
+        }
+
+        // The trailing CRC32 covers every byte that precedes it (header plus counters).
+        var bodyLength = bytes.Length - FooterLength;
+        var expectedChecksum = BinaryPrimitives.ReadUInt32LittleEndian(bytes.AsSpan(bodyLength, FooterLength));
+        if (ComputeCrc32(bytes.AsSpan(0, bodyLength)) != expectedChecksum)
+        {
+            throw new InvalidDataException("Recall heat map CRC32 checksum mismatch.");
+        }
+
+        using var reader = new BinaryReader(new MemoryStream(bytes, writable: false), Encoding.ASCII, leaveOpen: false);
+        if (!reader.ReadBytes(MagicHeader.Length).AsSpan().SequenceEqual(MagicHeader))
+        {
+            throw new InvalidDataException("Unsupported recall heat map binary header.");
+        }
+
+        var version = reader.ReadUInt16();
+        _ = reader.ReadUInt16(); // Reserved field; Serialize writes it as zero.
+        var vocabSize = reader.ReadInt32();
+
+        if (version != FormatVersion)
+        {
+            throw new InvalidDataException($"Unsupported recall heat map version {version}.");
+        }
+
+        if (vocabSize <= 0)
+        {
+            throw new InvalidDataException($"Invalid vocab size {vocabSize} in recall heat map.");
+        }
+
+        // Check the declared vocab size against the real payload length before allocating, so a bogus size fails fast.
+        long chainSlots = ChainBucketTable.MaxBuckets;
+        var counterCount = (2L * vocabSize) + (2L * chainSlots) + (chainSlots * chainSlots);
+        if (HeaderLength + (counterCount * sizeof(long)) != bodyLength)
+        {
+            throw new InvalidDataException($"Recall heat map payload length does not match vocab size {vocabSize}.");
+        }
+
+        var attemptCounts = ReadCounters(reader, vocabSize);
+        var acceptCounts = ReadCounters(reader, vocabSize);
+        var chainAttemptCounts = ReadCounters(reader, ChainBucketTable.MaxBuckets);
+        var chainAcceptCounts = ReadCounters(reader, ChainBucketTable.MaxBuckets);
+        var chainTransitions = new long[ChainBucketTable.MaxBuckets, ChainBucketTable.MaxBuckets];
+        for (var row = 0; row < ChainBucketTable.MaxBuckets; row++)
+        {
+            for (var col = 0; col < ChainBucketTable.MaxBuckets; col++)
+            {
+                chainTransitions[row, col] = reader.ReadInt64();
+            }
+        }
+
+        var counters = new HeatMapCounters(vocabSize, attemptCounts, acceptCounts, chainAttemptCounts, chainAcceptCounts, chainTransitions);
+        return BucketRecallHeatMap.FromCounters(counters);
+    }
+
+    /// <summary>Reads <paramref name="count"/> consecutive 64-bit counters from <paramref name="reader"/>.</summary>
+    private static long[] ReadCounters(BinaryReader reader, int count)
+    {
+        var values = new long[count];
+        for (var i = 0; i < values.Length; i++)
+        {
+            values[i] = reader.ReadInt64();
+        }
+        return values;
+    }
+
+    /// <summary>Computes the reflected CRC-32 (polynomial 0xEDB88320) of <paramref name="data"/>, one bit at a time.</summary>
+    private static uint ComputeCrc32(ReadOnlySpan<byte> data)
+    {
+        const uint ReflectedPolynomial = 0xEDB88320u;
+        var remainder = uint.MaxValue;
+        for (var index = 0; index < data.Length; index++)
+        {
+            remainder ^= data[index];
+            for (var shift = 8; shift > 0; shift--)
+            {
+                var shifted = remainder >> 1;
+                remainder = (remainder & 1u) != 0u ? shifted ^ ReflectedPolynomial : shifted;
+            }
+        }
+
+        return ~remainder;
+    }
+}
diff --git a/src/BitNetSharp.Core/Bucketing/BucketRecallVisualizer.cs b/src/BitNetSharp.Core/Bucketing/BucketRecallVisualizer.cs
new file mode 100644
index 0000000..98ad684
--- /dev/null
+++ b/src/BitNetSharp.Core/Bucketing/BucketRecallVisualizer.cs
@@ -0,0 +1,212 @@
+using System.Text;
+
+namespace BitNetSharp.Core.Bucketing;
+
+/// <summary>
+/// Renders bucket recall heat map data as Mermaid diagrams for visualization
+/// in markdown, GitHub, or any Mermaid-compatible renderer.
+/// </summary>
+public static class BucketRecallVisualizer
+{
+    /// <summary>
+    /// Renders a Mermaid xychart-beta bar chart of the top tokens by accept count.
+    /// </summary>
+    public static string RenderTokenHeatMap(
+        BucketRecallHeatMap heatMap,
+        Func<int, string> tokenResolver,
+        int maxTokens = 15)
+    {
+        ArgumentNullException.ThrowIfNull(heatMap);
+        ArgumentNullException.ThrowIfNull(tokenResolver);
+
+        var topTokens = heatMap.GetTopTokensByAcceptCount(maxTokens);
+        var sb = new StringBuilder();
+        sb.AppendLine("```mermaid");
+        sb.AppendLine("xychart-beta");
+        sb.AppendLine(" title \"Token Recall Heat Map (by accept count)\"");
+
+        if (topTokens.Count == 0)
+        {
+            sb.AppendLine(" x-axis [\"(no data)\"]"); // placeholder axis so the chart block is still emitted
+            sb.AppendLine(" bar [0]");
+        }
+        else
+        {
+            var labels = string.Join(", ", topTokens.Select(t => $"\"{Escape(tokenResolver(t.TokenId))}\""));
+            var values = string.Join(", ", topTokens.Select(static t => t.AcceptCount));
+            sb.AppendLine($" x-axis [{labels}]");
+            sb.AppendLine($" bar [{values}]");
+        }
+
+        sb.AppendLine("```");
+        return sb.ToString();
+    }
+
+    /// <summary>
+    /// Renders a Mermaid xychart-beta bar chart of chain recall rates.
+    /// </summary>
+    public static string RenderChainRecallChart(
+        BucketRecallHeatMap heatMap,
+        ChainBucketTable table,
+        Func<int, string> tokenResolver,
+        int maxChains = 20)
+    {
+        ArgumentNullException.ThrowIfNull(heatMap);
+        ArgumentNullException.ThrowIfNull(table);
+        ArgumentNullException.ThrowIfNull(tokenResolver);
+
+        var chains = table.Buckets
+            .Where(b => heatMap.GetChainAttemptCount(b.ChainId) > 0) // chains never attempted are left out
+            .OrderByDescending(b => heatMap.GetChainRecallRate(b.ChainId))
+            .Take(maxChains)
+            .ToArray();
+
+        var sb = new StringBuilder();
+        sb.AppendLine("```mermaid");
+        sb.AppendLine("xychart-beta");
+        sb.AppendLine(" title \"Chain Recall Rate\"");
+
+        if (chains.Length == 0)
+        {
+            sb.AppendLine(" x-axis [\"(no data)\"]");
+            sb.AppendLine(" bar [0]");
+        }
+        else
+        {
+            var labels = string.Join(", ", chains.Select(c =>
+                $"\"{Escape(FormatChainLabel(c, tokenResolver))}\""));
+            var values = string.Join(", ", chains.Select(c =>
+                (heatMap.GetChainRecallRate(c.ChainId) * 100).ToString("F0"))); // rate scaled by 100 for the "Recall %" axis
+            sb.AppendLine($" x-axis [{labels}]");
+            sb.AppendLine(" y-axis \"Recall %\"");
+            sb.AppendLine($" bar [{values}]");
+        }
+
+        sb.AppendLine("```");
+        return sb.ToString();
+    }
+
+    /// <summary>
+    /// Renders a Mermaid flowchart LR showing hot-path chain sequences with transition counts on edges.
+    /// </summary>
+    public static string RenderHotPathDiagram(
+        BucketRecallHeatMap heatMap,
+        ChainBucketTable table,
+        Func<int, string> tokenResolver,
+        int maxDepth = 5,
+        int maxPaths = 5)
+    {
+        ArgumentNullException.ThrowIfNull(heatMap);
+        ArgumentNullException.ThrowIfNull(table);
+        ArgumentNullException.ThrowIfNull(tokenResolver);
+
+        var hotPaths = heatMap.GetHotPaths(maxDepth, maxPaths);
+
+        var sb = new StringBuilder();
+        sb.AppendLine("```mermaid");
+        sb.AppendLine("flowchart LR");
+
+        if (hotPaths.Count == 0)
+        {
+            sb.AppendLine(" empty[\"No hot-paths detected\"]");
+        }
+        else
+        {
+            var declaredNodes = new HashSet<byte>(); // a chain shared by several paths is declared only once
+            foreach (var path in hotPaths)
+            {
+                foreach (var chainId in path.ChainSequence)
+                {
+                    if (declaredNodes.Add(chainId))
+                    {
+                        var bucket = table.GetById(chainId);
+                        var label = bucket is not null
+                            ? FormatChainLabel(bucket, tokenResolver)
+                            : $"chain {chainId}"; // fallback label when the table has no bucket for this id
+                        var accepts = heatMap.GetChainAcceptCount(chainId);
+                        sb.AppendLine($" C{chainId}[\"{Escape(label)}<br/>accepts: {accepts}\"]");
+                    }
+                }
+
+                for (var i = 0; i < path.ChainSequence.Length - 1; i++)
+                {
+                    var from = path.ChainSequence[i];
+                    var to = path.ChainSequence[i + 1];
+                    var count = heatMap.GetTransitionCount(from, to);
+                    sb.AppendLine($" C{from} -->|{count}| C{to}");
+                }
+            }
+        }
+
+        sb.AppendLine("```");
+        return sb.ToString();
+    }
+
+    /// <summary>
+    /// Renders a Mermaid flowchart showing all chains with color coding:
+    /// green for hot-path chains, red for low-value compaction candidates.
+    /// </summary>
+    public static string RenderCompactionReport(
+        BucketRecallHeatMap heatMap,
+        ChainBucketTable table,
+        Func<int, string> tokenResolver,
+        double threshold = 0.5)
+    {
+        ArgumentNullException.ThrowIfNull(heatMap);
+        ArgumentNullException.ThrowIfNull(table);
+        ArgumentNullException.ThrowIfNull(tokenResolver);
+
+        var rankings = heatMap.RankBucketsForCompaction(table);
+        var lowValue = heatMap.IdentifyLowValueBuckets(table, threshold);
+
+        var sb = new StringBuilder();
+        sb.AppendLine("```mermaid");
+        sb.AppendLine("flowchart TD");
+
+        var greenNodes = new List<string>();
+        var redNodes = new List<string>();
+
+        foreach (var ranking in rankings)
+        {
+            var bucket = table.GetById(ranking.ChainId);
+            var label = bucket is not null
+                ? FormatChainLabel(bucket, tokenResolver)
+                : $"chain {ranking.ChainId}";
+            var rate = ranking.AggregateRecallRate * 100;
+            var nodeId = $"C{ranking.ChainId}";
+
+            sb.AppendLine($" {nodeId}[\"{Escape(label)}<br/>recall: {rate:F0}% accepts: {ranking.TotalAcceptCount}\"]");
+
+            if (ranking.OnHotPath) // hot-path status wins over low-value status
+            {
+                greenNodes.Add(nodeId);
+            }
+            else if (lowValue.Contains(ranking.ChainId))
+            {
+                redNodes.Add(nodeId);
+            }
+        }
+
+        if (greenNodes.Count > 0)
+        {
+            sb.AppendLine($" style {string.Join(",", greenNodes)} fill:#4c4,stroke:#393,color:#fff");
+        }
+
+        if (redNodes.Count > 0)
+        {
+            sb.AppendLine($" style {string.Join(",", redNodes)} fill:#f44,stroke:#c33,color:#fff");
+        }
+
+        sb.AppendLine("```");
+        return sb.ToString();
+    }
+
+    private static string FormatChainLabel(ChainBucket bucket, Func<int, string> tokenResolver) // space-joined resolved tokens
+    {
+        var tokens = bucket.TokenIds.Select(tokenResolver);
+        return string.Join(" ", tokens);
+    }
+
+    private static string Escape(string text) => // swaps double quotes for "#quot;" so labels stay inside their quoted string
+        text.Replace("\"", "#quot;", StringComparison.Ordinal);
+}
diff --git a/src/BitNetSharp.Core/Serialization/BitNetPaperModelSnapshot.cs b/src/BitNetSharp.Core/Serialization/BitNetPaperModelSnapshot.cs
index ebf33bf..56dd6dd 100644
--- a/src/BitNetSharp.Core/Serialization/BitNetPaperModelSnapshot.cs
+++ b/src/BitNetSharp.Core/Serialization/BitNetPaperModelSnapshot.cs
@@ -16,7 +16,8 @@ internal sealed record BitNetPaperModelSnapshot(
IReadOnlyList TransformerProjectionWeights,
IReadOnlyList NormScales,
float[,] OutputHeadWeights,
- IReadOnlyDictionary MemorizedResponses)
+ IReadOnlyDictionary MemorizedResponses,
+ bool EnableRecallHeatMap = true)
{
internal const int DefaultBootstrapSeed = 42;
@@ -38,7 +39,8 @@ public static BitNetPaperModelSnapshot Capture(BitNetPaperModel model)
model.ExportTransformerProjectionWeights().Select(CloneMatrix).ToArray(),
model.ExportNormScales().Select(CloneVector).ToArray(),
CloneMatrix(model.ExportOutputHeadWeights()),
- CloneMemorizedResponses(model.ExportMemorizedResponses()));
+ CloneMemorizedResponses(model.ExportMemorizedResponses()),
+ model.Options.EnableRecallHeatMap);
}
public BitNetPaperModel Restore(VerbosityLevel verbosity = VerbosityLevel.Normal)
@@ -51,7 +53,8 @@ public BitNetPaperModel Restore(VerbosityLevel verbosity = VerbosityLevel.Normal
PrimaryLanguage,
EnableChainBuckets,
EnableSequenceCompression,
- ChainBucketAcceptanceThreshold),
+ ChainBucketAcceptanceThreshold,
+ EnableRecallHeatMap),
Config,
BootstrapSeed);
diff --git a/tests/BitNetSharp.Tests/BucketRecallHeatMapSerializerTests.cs b/tests/BitNetSharp.Tests/BucketRecallHeatMapSerializerTests.cs
new file mode 100644
index 0000000..ae556f2
--- /dev/null
+++ b/tests/BitNetSharp.Tests/BucketRecallHeatMapSerializerTests.cs
@@ -0,0 +1,152 @@
+using BitNetSharp.Core.Bucketing;
+
+namespace BitNetSharp.Tests;
+
+public sealed class BucketRecallHeatMapSerializerTests
+{
+    [Fact]
+    public void RoundTrip_PreservesAllCountersAndTransitions()
+    {
+        var original = new BucketRecallHeatMap(64);
+
+        original.RecordChainAttempt(0, [5, 10, 15], speculativeStartIndex: 1);
+        original.RecordChainAttempt(0, [5, 10, 15], speculativeStartIndex: 1);
+        original.RecordTokenAccepted(0, 10);
+        original.RecordChainAccepted(0);
+        original.RecordChainAccepted(3);
+        original.ResetGenerationState();
+        original.RecordChainAccepted(0);
+        original.RecordChainAccepted(3);
+
+        using var stream = new MemoryStream();
+        BucketRecallHeatMapSerializer.Serialize(original, stream);
+        stream.Position = 0; // rewind so Deserialize reads from the start
+        var restored = BucketRecallHeatMapSerializer.Deserialize(stream);
+
+        Assert.Equal(original.VocabSize, restored.VocabSize);
+        Assert.Equal(original.GetAttemptCount(10), restored.GetAttemptCount(10));
+        Assert.Equal(original.GetAttemptCount(15), restored.GetAttemptCount(15));
+        Assert.Equal(original.GetAcceptCount(10), restored.GetAcceptCount(10));
+        Assert.Equal(original.GetChainAttemptCount(0), restored.GetChainAttemptCount(0));
+        Assert.Equal(original.GetChainAcceptCount(0), restored.GetChainAcceptCount(0));
+        Assert.Equal(original.GetChainAcceptCount(3), restored.GetChainAcceptCount(3));
+        Assert.Equal(original.GetTransitionCount(0, 3), restored.GetTransitionCount(0, 3));
+        Assert.Equal(2, restored.GetTransitionCount(0, 3));
+    }
+
+    [Fact]
+    public void Deserialize_RejectsCorruptedChecksum()
+    {
+        var original = new BucketRecallHeatMap(16);
+        original.RecordChainAttempt(0, [1, 2], speculativeStartIndex: 0);
+
+        using var stream = new MemoryStream();
+        BucketRecallHeatMapSerializer.Serialize(original, stream);
+        var bytes = stream.ToArray();
+        bytes[20] ^= 0x01; // flip one bit and leave the stored checksum untouched
+
+        using var corrupted = new MemoryStream(bytes);
+        var exception = Assert.Throws<InvalidDataException>(() => BucketRecallHeatMapSerializer.Deserialize(corrupted)); // NOTE(review): type assumed - confirm
+        Assert.Contains("CRC32", exception.Message, StringComparison.Ordinal);
+    }
+
+    [Fact]
+    public void Deserialize_RejectsInvalidMagic()
+    {
+        var original = new BucketRecallHeatMap(16);
+
+        using var stream = new MemoryStream();
+        BucketRecallHeatMapSerializer.Serialize(original, stream);
+        var bytes = stream.ToArray();
+        bytes[0] = (byte)'X';
+
+        // Recompute CRC32 so the checksum doesn't fail first.
+        var crc = ComputeCrc32(bytes.AsSpan(0, bytes.Length - 4));
+        BitConverter.TryWriteBytes(bytes.AsSpan(bytes.Length - 4), crc);
+
+        using var corrupted = new MemoryStream(bytes);
+        var exception = Assert.Throws<InvalidDataException>(() => BucketRecallHeatMapSerializer.Deserialize(corrupted)); // NOTE(review): type assumed - confirm
+        Assert.Contains("header", exception.Message, StringComparison.OrdinalIgnoreCase);
+    }
+
+    [Fact]
+    public void Deserialize_RejectsVersionMismatch()
+    {
+        var original = new BucketRecallHeatMap(16);
+
+        using var stream = new MemoryStream();
+        BucketRecallHeatMapSerializer.Serialize(original, stream);
+        var bytes = stream.ToArray();
+        // Version is at offset 4 (2 bytes LE). Set to 99.
+        bytes[4] = 99;
+        bytes[5] = 0;
+
+        var crc = ComputeCrc32(bytes.AsSpan(0, bytes.Length - 4)); // re-sign so the version check is what fails
+        BitConverter.TryWriteBytes(bytes.AsSpan(bytes.Length - 4), crc);
+
+        using var corrupted = new MemoryStream(bytes);
+        var exception = Assert.Throws<InvalidDataException>(() => BucketRecallHeatMapSerializer.Deserialize(corrupted));
+        Assert.Contains("version", exception.Message, StringComparison.OrdinalIgnoreCase);
+    }
+
+    [Fact]
+    public void RoundTripWithEmptyHeatMap()
+    {
+        var original = new BucketRecallHeatMap(32);
+
+        using var stream = new MemoryStream();
+        BucketRecallHeatMapSerializer.Serialize(original, stream);
+        stream.Position = 0;
+        var restored = BucketRecallHeatMapSerializer.Deserialize(stream);
+
+        Assert.Equal(32, restored.VocabSize);
+        Assert.Equal(0, restored.GetAttemptCount(0));
+        Assert.Equal(0, restored.GetChainAcceptCount(0));
+        Assert.Equal(0, restored.GetTransitionCount(0, 0));
+    }
+
+    [Fact]
+    public void FileRoundTrip_PreservesData()
+    {
+        var original = new BucketRecallHeatMap(32);
+        original.RecordChainAttempt(5, [1, 2, 3], 1);
+        original.RecordTokenAccepted(5, 2);
+        original.RecordChainAccepted(5);
+
+        var tempPath = Path.Combine(Path.GetTempPath(), $"heatmap-test-{Guid.NewGuid():N}.bin");
+        try
+        {
+            BucketRecallHeatMapSerializer.Save(original, tempPath);
+            var restored = BucketRecallHeatMapSerializer.Load(tempPath);
+
+            Assert.Equal(32, restored.VocabSize);
+            Assert.Equal(1, restored.GetAttemptCount(2));
+            Assert.Equal(1, restored.GetAcceptCount(2));
+            Assert.Equal(1, restored.GetChainAcceptCount(5));
+        }
+        finally
+        {
+            if (File.Exists(tempPath)) // clean up the temp file whether or not the assertions passed
+            {
+                File.Delete(tempPath);
+            }
+        }
+    }
+
+    private static uint ComputeCrc32(ReadOnlySpan<byte> data) // local checksum used to re-sign tampered payloads
+    {
+        const uint polynomial = 0xEDB88320u;
+        var crc = 0xFFFFFFFFu;
+        foreach (var value in data)
+        {
+            crc ^= value;
+            for (var bit = 0; bit < 8; bit++)
+            {
+                var mask = (crc & 1u) == 0u ? 0u : polynomial;
+                crc = (crc >> 1) ^ mask;
+            }
+        }
+
+        return ~crc;
+    }
+}
diff --git a/tests/BitNetSharp.Tests/BucketRecallHeatMapTests.cs b/tests/BitNetSharp.Tests/BucketRecallHeatMapTests.cs
new file mode 100644
index 0000000..25d6746
--- /dev/null
+++ b/tests/BitNetSharp.Tests/BucketRecallHeatMapTests.cs
@@ -0,0 +1,469 @@
+using BitNetSharp.Core.Bucketing;
+
+namespace BitNetSharp.Tests;
+
+public sealed class BucketRecallHeatMapTests
+{
+    private const int TestVocabSize = 32;
+
+    [Fact]
+    public void RecordChainAttempt_IncrementsAttemptCountsForSpeculativeTokens()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+        var tokenIds = new[] { 5, 10, 15, 20 };
+
+        heatMap.RecordChainAttempt(0, tokenIds, speculativeStartIndex: 2);
+
+        Assert.Equal(0, heatMap.GetAttemptCount(5)); // tokens before the speculative start index are not counted
+        Assert.Equal(0, heatMap.GetAttemptCount(10));
+        Assert.Equal(1, heatMap.GetAttemptCount(15));
+        Assert.Equal(1, heatMap.GetAttemptCount(20));
+    }
+
+    [Fact]
+    public void RecordChainAttempt_IncrementsChainAttemptCount()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordChainAttempt(3, [1, 2, 3], speculativeStartIndex: 1);
+        heatMap.RecordChainAttempt(3, [1, 2, 3], speculativeStartIndex: 1);
+
+        Assert.Equal(2, heatMap.GetChainAttemptCount(3));
+    }
+
+    [Fact]
+    public void RecordTokenAccepted_IncrementsAcceptCounters()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordTokenAccepted(0, 7);
+        heatMap.RecordTokenAccepted(0, 7);
+        heatMap.RecordTokenAccepted(0, 7);
+
+        Assert.Equal(3, heatMap.GetAcceptCount(7));
+    }
+
+    [Fact]
+    public void RecordChainAccepted_RecordsTransitionFromPreviousChain()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordChainAccepted(1);
+        heatMap.RecordChainAccepted(5);
+
+        Assert.Equal(1, heatMap.GetTransitionCount(1, 5));
+    }
+
+    [Fact]
+    public void RecordChainAccepted_NoTransitionOnFirstChainInGeneration()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordChainAccepted(3);
+
+        Assert.Equal(1, heatMap.GetChainAcceptCount(3));
+        // No transition should exist since there was no previous chain.
+        for (var from = 0; from < ChainBucketTable.MaxBuckets; from++)
+        {
+            Assert.Equal(0, heatMap.GetTransitionCount((byte)from, 3));
+        }
+    }
+
+    [Fact]
+    public void ResetGenerationState_ClearsLastAcceptedChainId()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordChainAccepted(1);
+        heatMap.ResetGenerationState();
+        heatMap.RecordChainAccepted(5);
+
+        // No transition from 1→5 because generation state was reset.
+        Assert.Equal(0, heatMap.GetTransitionCount(1, 5));
+    }
+
+    [Fact]
+    public void GetTokenRecallRate_ReturnsZeroWhenNeverAttempted()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        Assert.Equal(0d, heatMap.GetTokenRecallRate(10));
+    }
+
+    [Fact]
+    public void GetTokenRecallRate_ReturnsCorrectRatio()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordChainAttempt(0, [5], speculativeStartIndex: 0);
+        heatMap.RecordChainAttempt(0, [5], speculativeStartIndex: 0);
+        heatMap.RecordChainAttempt(0, [5], speculativeStartIndex: 0);
+        heatMap.RecordChainAttempt(0, [5], speculativeStartIndex: 0);
+        heatMap.RecordTokenAccepted(0, 5);
+        heatMap.RecordTokenAccepted(0, 5);
+        heatMap.RecordTokenAccepted(0, 5);
+
+        Assert.Equal(0.75d, heatMap.GetTokenRecallRate(5)); // 3 accepts out of 4 attempts
+    }
+
+    [Fact]
+    public void GetTransitionCount_ReturnsRecordedCount()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordChainAccepted(2);
+        heatMap.RecordChainAccepted(7);
+        heatMap.ResetGenerationState();
+        heatMap.RecordChainAccepted(2);
+        heatMap.RecordChainAccepted(7);
+
+        Assert.Equal(2, heatMap.GetTransitionCount(2, 7));
+    }
+
+    [Fact]
+    public void GetIncomingChains_ReturnsPredecessorsSortedByCount()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        // Chain 10 is reached from chain 1 (3 times) and chain 2 (1 time).
+        for (var i = 0; i < 3; i++)
+        {
+            heatMap.RecordChainAccepted(1);
+            heatMap.RecordChainAccepted(10);
+            heatMap.ResetGenerationState();
+        }
+
+        heatMap.RecordChainAccepted(2);
+        heatMap.RecordChainAccepted(10);
+
+        var incoming = heatMap.GetIncomingChains(10);
+        Assert.Equal(2, incoming.Count);
+        Assert.Equal(1, incoming[0].ChainId);
+        Assert.Equal(3, incoming[0].TransitionCount);
+        Assert.Equal(2, incoming[1].ChainId);
+        Assert.Equal(1, incoming[1].TransitionCount);
+    }
+
+    [Fact]
+    public void GetOutgoingChains_ReturnsSuccessorsSortedByCount()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        // Chain 5 leads to chain 8 (2 times) and chain 9 (1 time).
+        heatMap.RecordChainAccepted(5);
+        heatMap.RecordChainAccepted(8);
+        heatMap.ResetGenerationState();
+        heatMap.RecordChainAccepted(5);
+        heatMap.RecordChainAccepted(8);
+        heatMap.ResetGenerationState();
+        heatMap.RecordChainAccepted(5);
+        heatMap.RecordChainAccepted(9);
+
+        var outgoing = heatMap.GetOutgoingChains(5);
+        Assert.Equal(2, outgoing.Count);
+        Assert.Equal(8, outgoing[0].ChainId);
+        Assert.Equal(2, outgoing[0].TransitionCount);
+        Assert.Equal(9, outgoing[1].ChainId);
+        Assert.Equal(1, outgoing[1].TransitionCount);
+    }
+
+    [Fact]
+    public void GetHotPaths_FindsFrequentChainSequences()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        // Create a clear hot-path: 1 → 2 → 3 (each transition fires 5 times).
+        for (var i = 0; i < 5; i++)
+        {
+            heatMap.RecordChainAccepted(1);
+            heatMap.RecordChainAccepted(2);
+            heatMap.RecordChainAccepted(3);
+            heatMap.ResetGenerationState();
+        }
+
+        var hotPaths = heatMap.GetHotPaths(maxDepth: 5, maxResults: 10, minTransitions: 2);
+        Assert.True(hotPaths.Count > 0);
+
+        var mainPath = hotPaths[0];
+        Assert.True(mainPath.ChainSequence.Length >= 3);
+        Assert.Equal(1, mainPath.ChainSequence[0]);
+        Assert.Equal(2, mainPath.ChainSequence[1]);
+        Assert.Equal(3, mainPath.ChainSequence[2]);
+        Assert.Equal(5, mainPath.MinTransitionCount);
+    }
+
+    [Fact]
+    public void GetHotPaths_RespectsMinTransitionsThreshold()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        // Only 1 transition — below the default minTransitions=2.
+        heatMap.RecordChainAccepted(1);
+        heatMap.RecordChainAccepted(2);
+
+        var hotPaths = heatMap.GetHotPaths(minTransitions: 2);
+        Assert.Empty(hotPaths);
+    }
+
+    [Fact]
+    public void GetTopTokensByAcceptCount_ReturnsDescendingOrder()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordTokenAccepted(0, 3);
+        heatMap.RecordTokenAccepted(0, 7);
+        heatMap.RecordTokenAccepted(0, 7);
+        heatMap.RecordTokenAccepted(0, 7);
+        heatMap.RecordTokenAccepted(0, 5);
+        heatMap.RecordTokenAccepted(0, 5);
+
+        var top = heatMap.GetTopTokensByAcceptCount(maxResults: 3);
+        Assert.Equal(3, top.Count);
+        Assert.Equal(7, top[0].TokenId);
+        Assert.Equal(3, top[0].AcceptCount);
+        Assert.Equal(5, top[1].TokenId);
+        Assert.Equal(2, top[1].AcceptCount);
+        Assert.Equal(3, top[2].TokenId);
+        Assert.Equal(1, top[2].AcceptCount);
+    }
+
+    [Fact]
+    public void GetTopTokensByRecallRate_RespectsMinAttempts()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        // Token 3: 1 attempt, 1 accept (100% rate but below minAttempts=3).
+        heatMap.RecordChainAttempt(0, [3], 0);
+        heatMap.RecordTokenAccepted(0, 3);
+
+        // Token 7: 5 attempts, 4 accepts (80% rate, above minAttempts).
+        for (var i = 0; i < 5; i++)
+        {
+            heatMap.RecordChainAttempt(0, [7], 0);
+        }
+
+        for (var i = 0; i < 4; i++)
+        {
+            heatMap.RecordTokenAccepted(0, 7);
+        }
+
+        var top = heatMap.GetTopTokensByRecallRate(maxResults: 10, minAttempts: 3);
+        Assert.Single(top);
+        Assert.Equal(7, top[0].TokenId);
+    }
+
+    [Fact]
+    public void RankBucketsForCompaction_ReturnsWorstFirst()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+        var table = new ChainBucketTable([
+            new ChainBucket(0, [1, 2], 1.0f),
+            new ChainBucket(1, [3, 4], 1.0f)
+        ]);
+
+        // Chain 0: 10 attempts, 8 accepts (80%).
+        for (var i = 0; i < 10; i++)
+        {
+            heatMap.RecordChainAttempt(0, [1, 2], 1);
+        }
+
+        for (var i = 0; i < 8; i++)
+        {
+            heatMap.RecordChainAccepted(0);
+        }
+
+        // Chain 1: 10 attempts, 2 accepts (20%).
+        for (var i = 0; i < 10; i++)
+        {
+            heatMap.RecordChainAttempt(1, [3, 4], 1);
+        }
+
+        for (var i = 0; i < 2; i++)
+        {
+            heatMap.RecordChainAccepted(1);
+        }
+
+        var rankings = heatMap.RankBucketsForCompaction(table);
+        Assert.Equal(2, rankings.Count);
+        Assert.Equal(1, rankings[0].ChainId); // Worst first (20%).
+        Assert.Equal(0, rankings[1].ChainId); // Better (80%).
+    }
+
+    [Fact]
+    public void RankBucketsForCompaction_MarksHotPathChains()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+        var table = new ChainBucketTable([
+            new ChainBucket(0, [1, 2], 1.0f),
+            new ChainBucket(1, [3, 4], 1.0f),
+            new ChainBucket(2, [5, 6], 1.0f)
+        ]);
+
+        // Create a hot-path: 0 → 1 (5 transitions).
+        for (var i = 0; i < 5; i++)
+        {
+            heatMap.RecordChainAccepted(0);
+            heatMap.RecordChainAccepted(1);
+            heatMap.ResetGenerationState();
+        }
+
+        var rankings = heatMap.RankBucketsForCompaction(table);
+        var chain0 = rankings.Single(r => r.ChainId == 0);
+        var chain1 = rankings.Single(r => r.ChainId == 1);
+        var chain2 = rankings.Single(r => r.ChainId == 2);
+
+        Assert.True(chain0.OnHotPath);
+        Assert.True(chain1.OnHotPath);
+        Assert.False(chain2.OnHotPath);
+    }
+
+    [Fact]
+    public void IdentifyLowValueBuckets_ReturnsChainsBelowThreshold()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+        var table = new ChainBucketTable([
+            new ChainBucket(0, [1, 2], 1.0f),
+            new ChainBucket(1, [3, 4], 1.0f)
+        ]);
+
+        // Chain 0: 10 attempts, 1 accept (10% — below 50% threshold).
+        for (var i = 0; i < 10; i++)
+        {
+            heatMap.RecordChainAttempt(0, [1, 2], 1);
+        }
+
+        heatMap.RecordChainAccepted(0);
+
+        // Chain 1: 10 attempts, 8 accepts (80% — above threshold).
+        for (var i = 0; i < 10; i++)
+        {
+            heatMap.RecordChainAttempt(1, [3, 4], 1);
+        }
+
+        for (var i = 0; i < 8; i++)
+        {
+            heatMap.RecordChainAccepted(1);
+        }
+
+        var lowValue = heatMap.IdentifyLowValueBuckets(table, threshold: 0.5);
+        Assert.Contains((byte)0, lowValue);
+        Assert.DoesNotContain((byte)1, lowValue);
+    }
+
+    [Fact]
+    public void IdentifyLowValueBuckets_ExcludesHotPathChains()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+        var table = new ChainBucketTable([
+            new ChainBucket(0, [1, 2], 1.0f),
+            new ChainBucket(1, [3, 4], 1.0f)
+        ]);
+
+        // Chain 0: low recall but on a hot-path.
+        for (var i = 0; i < 10; i++)
+        {
+            heatMap.RecordChainAttempt(0, [1, 2], 1);
+        }
+
+        heatMap.RecordChainAccepted(0);
+
+        // Create hot-path: 0 → 1 (3 transitions).
+        for (var i = 0; i < 3; i++)
+        {
+            heatMap.RecordChainAccepted(0);
+            heatMap.RecordChainAccepted(1);
+            heatMap.ResetGenerationState();
+        }
+
+        var lowValue = heatMap.IdentifyLowValueBuckets(table, threshold: 0.5, minAttempts: 2);
+        Assert.DoesNotContain((byte)0, lowValue); // Protected by hot-path.
+    }
+
+    [Fact]
+    public void Reset_ClearsAllCountersAndTransitions()
+    {
+        var heatMap = new BucketRecallHeatMap(TestVocabSize);
+
+        heatMap.RecordChainAttempt(0, [5, 10], 0);
+        heatMap.RecordTokenAccepted(0, 5);
+        heatMap.RecordChainAccepted(0);
+        heatMap.RecordChainAccepted(1);
+
+        heatMap.Reset();
+
+        Assert.Equal(0, heatMap.GetAttemptCount(5));
+        Assert.Equal(0, heatMap.GetAcceptCount(5));
+        Assert.Equal(0, heatMap.GetChainAttemptCount(0));
+        Assert.Equal(0, heatMap.GetChainAcceptCount(0));
+        Assert.Equal(0, heatMap.GetTransitionCount(0, 1));
+    }
+
+    [Fact]
+    public void MergeFrom_AddsCounts()
+    {
+        var a = new BucketRecallHeatMap(TestVocabSize);
+        var b = new BucketRecallHeatMap(TestVocabSize);
+
+        a.RecordChainAttempt(0, [5], 0);
+        a.RecordTokenAccepted(0, 5);
+        b.RecordChainAttempt(0, [5], 0);
+        b.RecordTokenAccepted(0, 5);
+
+        a.MergeFrom(b);
+
+        Assert.Equal(2, a.GetAttemptCount(5));
+        Assert.Equal(2, a.GetAcceptCount(5));
+    }
+
+    [Fact]
+    public void MergeFrom_AddsTransitionCounts()
+    {
+        var a = new BucketRecallHeatMap(TestVocabSize);
+        var b = new BucketRecallHeatMap(TestVocabSize);
+
+        a.RecordChainAccepted(1);
+        a.RecordChainAccepted(2);
+
+        b.RecordChainAccepted(1);
+        b.RecordChainAccepted(2);
+        b.ResetGenerationState();
+        b.RecordChainAccepted(1);
+        b.RecordChainAccepted(2);
+
+        a.MergeFrom(b);
+
+        Assert.Equal(3, a.GetTransitionCount(1, 2)); // 1 transition from a plus 2 from b
+    }
+
+    [Fact]
+    public void MergeFrom_ThrowsOnVocabSizeMismatch()
+    {
+        var a = new BucketRecallHeatMap(32);
+        var b = new BucketRecallHeatMap(64);
+        // NOTE(review): concrete exception type is not visible here; tighten to Assert.Throws<T> once confirmed.
+        Assert.ThrowsAny<Exception>(() => a.MergeFrom(b));
+    }
+
+    [Fact]
+    public void ExportCounters_RoundTripsViaFromCounters()
+    {
+        var original = new BucketRecallHeatMap(TestVocabSize);
+
+        original.RecordChainAttempt(0, [5, 10], 0);
+        original.RecordTokenAccepted(0, 5);
+        original.RecordChainAccepted(0);
+        original.RecordChainAccepted(1);
+
+        var counters = original.ExportCounters();
+        var restored = BucketRecallHeatMap.FromCounters(counters);
+
+        Assert.Equal(original.VocabSize, restored.VocabSize);
+        Assert.Equal(original.GetAttemptCount(5), restored.GetAttemptCount(5));
+        Assert.Equal(original.GetAttemptCount(10), restored.GetAttemptCount(10));
+        Assert.Equal(original.GetAcceptCount(5), restored.GetAcceptCount(5));
+        Assert.Equal(original.GetChainAttemptCount(0), restored.GetChainAttemptCount(0));
+        Assert.Equal(original.GetChainAcceptCount(0), restored.GetChainAcceptCount(0));
+        Assert.Equal(original.GetChainAcceptCount(1), restored.GetChainAcceptCount(1));
+        Assert.Equal(original.GetTransitionCount(0, 1), restored.GetTransitionCount(0, 1));
+    }
+}
diff --git a/tests/BitNetSharp.Tests/BucketRecallVisualizerTests.cs b/tests/BitNetSharp.Tests/BucketRecallVisualizerTests.cs
new file mode 100644
index 0000000..59f7ed8
--- /dev/null
+++ b/tests/BitNetSharp.Tests/BucketRecallVisualizerTests.cs
@@ -0,0 +1,157 @@
+using BitNetSharp.Core.Bucketing;
+
+namespace BitNetSharp.Tests;
+
+public sealed class BucketRecallVisualizerTests
+{
+    private const int TestVocabSize = 32;
+
+    private static string ResolveToken(int tokenId) => $"tok_{tokenId}"; // deterministic stand-in for a tokenizer lookup
+
+    [Fact]
+    public void RenderTokenHeatMap_StartsWithMermaidXyChartBlock()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+        recallMap.RecordTokenAccepted(0, 5);
+
+        var diagram = BucketRecallVisualizer.RenderTokenHeatMap(recallMap, ResolveToken);
+
+        Assert.StartsWith("```mermaid", diagram, StringComparison.Ordinal);
+        Assert.Contains("xychart-beta", diagram, StringComparison.Ordinal);
+        Assert.EndsWith("```", diagram.TrimEnd(), StringComparison.Ordinal);
+    }
+
+    [Fact]
+    public void RenderTokenHeatMap_ContainsExpectedTokenLabels()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+        recallMap.RecordTokenAccepted(0, 5);
+        recallMap.RecordTokenAccepted(0, 5);
+        recallMap.RecordTokenAccepted(0, 10);
+
+        var diagram = BucketRecallVisualizer.RenderTokenHeatMap(recallMap, ResolveToken);
+
+        Assert.Contains("tok_5", diagram, StringComparison.Ordinal);
+        Assert.Contains("tok_10", diagram, StringComparison.Ordinal);
+    }
+
+    [Fact]
+    public void RenderTokenHeatMap_ReturnsEmptyDiagramWhenNoData()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+
+        var diagram = BucketRecallVisualizer.RenderTokenHeatMap(recallMap, ResolveToken);
+
+        Assert.StartsWith("```mermaid", diagram, StringComparison.Ordinal);
+    }
+
+    [Fact]
+    public void RenderChainRecallChart_StartsWithMermaidXyChartBlock()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+        var bucketTable = new ChainBucketTable([new ChainBucket(0, [1, 2], 1.0f)]);
+        recallMap.RecordChainAttempt(0, [1, 2], 1);
+        recallMap.RecordChainAccepted(0);
+
+        var diagram = BucketRecallVisualizer.RenderChainRecallChart(recallMap, bucketTable, ResolveToken);
+
+        Assert.StartsWith("```mermaid", diagram, StringComparison.Ordinal);
+        Assert.Contains("xychart-beta", diagram, StringComparison.Ordinal);
+    }
+
+    [Fact]
+    public void RenderChainRecallChart_ContainsChainTokenLabels()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+        var bucketTable = new ChainBucketTable([new ChainBucket(0, [1, 2], 1.0f)]);
+        recallMap.RecordChainAttempt(0, [1, 2], 1);
+        recallMap.RecordChainAccepted(0);
+
+        var diagram = BucketRecallVisualizer.RenderChainRecallChart(recallMap, bucketTable, ResolveToken);
+
+        Assert.Contains("tok_1", diagram, StringComparison.Ordinal);
+    }
+
+    [Fact]
+    public void RenderHotPathDiagram_StartsWithMermaidFlowchart()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+        var bucketTable = new ChainBucketTable([
+            new ChainBucket(0, [1, 2], 1.0f),
+            new ChainBucket(1, [3, 4], 1.0f)
+        ]);
+
+        for (var round = 0; round < 5; round++)
+        {
+            recallMap.RecordChainAccepted(0);
+            recallMap.RecordChainAccepted(1);
+            recallMap.ResetGenerationState();
+        }
+
+        var diagram = BucketRecallVisualizer.RenderHotPathDiagram(recallMap, bucketTable, ResolveToken);
+
+        Assert.StartsWith("```mermaid", diagram, StringComparison.Ordinal);
+        Assert.Contains("flowchart LR", diagram, StringComparison.Ordinal);
+    }
+
+    [Fact]
+    public void RenderHotPathDiagram_ContainsEdgesWithTransitionCounts()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+        var bucketTable = new ChainBucketTable([
+            new ChainBucket(0, [1, 2], 1.0f),
+            new ChainBucket(1, [3, 4], 1.0f)
+        ]);
+
+        for (var round = 0; round < 3; round++)
+        {
+            recallMap.RecordChainAccepted(0);
+            recallMap.RecordChainAccepted(1);
+            recallMap.ResetGenerationState();
+        }
+
+        var diagram = BucketRecallVisualizer.RenderHotPathDiagram(recallMap, bucketTable, ResolveToken);
+
+        Assert.Contains("-->|3|", diagram, StringComparison.Ordinal); // edge label carries the 0 -> 1 transition count
+    }
+
+    [Fact]
+    public void RenderCompactionReport_StartsWithMermaidFlowchart()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+        var bucketTable = new ChainBucketTable([
+            new ChainBucket(0, [1, 2], 1.0f),
+            new ChainBucket(1, [3, 4], 1.0f)
+        ]);
+
+        for (var round = 0; round < 5; round++)
+        {
+            recallMap.RecordChainAttempt(0, [1, 2], 1);
+        }
+
+        recallMap.RecordChainAccepted(0);
+
+        var diagram = BucketRecallVisualizer.RenderCompactionReport(recallMap, bucketTable, ResolveToken);
+
+        Assert.StartsWith("```mermaid", diagram, StringComparison.Ordinal);
+        Assert.Contains("flowchart", diagram, StringComparison.Ordinal);
+    }
+
+    [Fact]
+    public void RenderCompactionReport_ColorsLowValueNodesRed()
+    {
+        var recallMap = new BucketRecallHeatMap(TestVocabSize);
+        var bucketTable = new ChainBucketTable([new ChainBucket(0, [1, 2], 1.0f)]);
+
+        for (var round = 0; round < 5; round++)
+        {
+            recallMap.RecordChainAttempt(0, [1, 2], 1);
+        }
+
+        recallMap.RecordChainAccepted(0);
+
+        var diagram = BucketRecallVisualizer.RenderCompactionReport(recallMap, bucketTable, ResolveToken, threshold: 0.5);
+
+        Assert.Contains("fill:#f44", diagram, StringComparison.Ordinal);
+    }
+}
diff --git a/tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs b/tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs
index a3e2df0..964415a 100644
--- a/tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs
+++ b/tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs
@@ -1,5 +1,6 @@
using BitNetSharp.App;
using BitNetSharp.Core;
+using BitNetSharp.Core.Bucketing;
namespace BitNetSharp.Tests;
@@ -211,6 +212,54 @@ public void BuiltInModelsPreserveTrainedResponsesAcrossCheckpointRoundTrips()
Assert.True(traditionalRoundTrip.ResponsesMatch);
}
+ [Fact]
+ public void ChainBucketRecallWithHeatMapEnabled_TracksTokenAcceptanceAndTransitions()
+ {
+     var corpus = BitNetTrainingCorpus.CreateBenchmarkExamples();
+     var model = BitNetPaperModel.CreateForTrainingCorpus(
+         corpus,
+         VerbosityLevel.Quiet,
+         enableChainBuckets: true);
+     model.Train(corpus, epochs: 3);
+     model.MineAndLoadBuckets(corpus);
+
+     Assert.NotNull(model.RecallHeatMap);
+
+     foreach (var sample in corpus)
+     {
+         model.GenerateResponse(sample.Prompt, maxTokens: 8);
+     }
+
+     // Sum the chain attempt counters for every id below ChainBucketTable.MaxBuckets. The
+     // assertion only requires a non-negative total, so a run with no matched chains passes.
+     var attemptTotal = 0L;
+     for (var chainIndex = 0; chainIndex < ChainBucketTable.MaxBuckets; chainIndex++)
+     {
+         attemptTotal += model.RecallHeatMap.GetChainAttemptCount((byte)chainIndex);
+     }
+
+     Assert.True(attemptTotal >= 0, "Heat map should be present and tracking (may be zero if no chains matched).");
+ }
+
+ [Fact]
+ public void ChainBucketRecallWithHeatMapDisabled_DoesNotAllocateHeatMap()
+ {
+     var corpus = BitNetTrainingCorpus.CreateBenchmarkExamples();
+     var model = new BitNetPaperModel(
+         new BitNetOptions(
+             BitNetTrainingCorpus.CreateBenchmarkVocabulary(),
+             VerbosityLevel.Quiet,
+             EnableChainBuckets: true,
+             EnableRecallHeatMap: false));
+     model.MineAndLoadBuckets(corpus);
+
+     Assert.Null(model.RecallHeatMap); // absent right after bucket loading
+
+     model.GenerateResponse("what is bitnet", maxTokens: 4);
+
+     Assert.Null(model.RecallHeatMap); // and still absent after a generation pass
+ }
+
private static async Task WithBenchmarkOptionsAsync(HostedAgentBenchmarkOptions options, Func assertion)
{
var originalValue = Environment.GetEnvironmentVariable(HostedAgentBenchmarkOptions.EnvironmentVariableName);