From b687668e62d387fa489f02e9c0fd9b76e7eb20a5 Mon Sep 17 00:00:00 2001 From: Sharp Ninja Date: Tue, 7 Apr 2026 12:20:15 -0500 Subject: [PATCH 1/2] Add model weights to Core NuGet package and publish to GitHub Packages - Generate default BitNet model weights (.gguf) during CI build and include them in the BitNetSharp.Core NuGet package as content files - Configure both Azure Pipelines and GitHub Actions to publish NuGet packages to https://nuget.pkg.github.com/sharpninja/index.json - Reference McpServer variable library for GH_TOKEN authentication in Azure Pipelines - Add nuget.config with GitHub Packages source and package source mapping to prevent dependency confusion - Bump version to 0.6.0 Co-Authored-By: Claude Opus 4.6 (1M context) --- .github/workflows/build.yml | 5 +++ .gitignore | 3 ++ .version | 2 +- Directory.Build.props | 2 +- GitVersion.yml | 2 +- azure-pipelines.yml | 40 +++++++++++--------- nuget.config | 15 ++++++++ src/BitNetSharp.Core/BitNetSharp.Core.csproj | 8 ++++ 8 files changed, 57 insertions(+), 20 deletions(-) create mode 100644 nuget.config diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e68cabe..60585bd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -45,6 +45,11 @@ jobs: - name: Test run: dotnet test BitNet-b1.58-Sharp.slnx --configuration Release --no-build --no-restore --filter "Category!=SlowLane" + - name: Generate default model weights + run: | + mkdir -p "${{ github.workspace }}/src/BitNetSharp.Core/Data/Models" + dotnet run --framework net9.0 --project "${{ github.workspace }}/src/BitNetSharp.App/BitNetSharp.App.csproj" --configuration Release --no-build -- export --output="${{ github.workspace }}/src/BitNetSharp.Core/Data/Models/bitnet-b1.58-default.gguf" + - name: Pack BitNetSharp.Core run: dotnet pack "${{ github.workspace }}/src/BitNetSharp.Core/BitNetSharp.Core.csproj" --configuration Release --no-build --no-restore -p:PackageVersion=${{ steps.gitversion.outputs.semVer 
}} --output "${{ github.workspace }}/artifacts/packages/core" diff --git a/.gitignore b/.gitignore index 855a6ef..e8b65a7 100644 --- a/.gitignore +++ b/.gitignore @@ -420,3 +420,6 @@ FodyWeavers.xsd AGENTS-README-FIRST.yaml .mcpServer/ + +# Generated model weights (produced during CI builds) +*.gguf diff --git a/.version b/.version index 6339605..a918a2a 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -0.1.0-78 +0.6.0 diff --git a/Directory.Build.props b/Directory.Build.props index 2addca1..2d12427 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,5 +1,5 @@ - 0.1.0 + 0.6.0 diff --git a/GitVersion.yml b/GitVersion.yml index 3f5c2ea..2a18fb4 100644 --- a/GitVersion.yml +++ b/GitVersion.yml @@ -1,4 +1,4 @@ -next-version: 0.1.0 +next-version: 0.6.0 mode: ContinuousDelivery tag-prefix: '[vV]?' strategies: diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 8ddfd10..ab349d7 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -13,14 +13,21 @@ pr: - "*" variables: - BuildConfiguration: Release - SolutionFile: BitNet-b1.58-Sharp.slnx - CoreProject: src/BitNetSharp.Core/BitNetSharp.Core.csproj - AppProject: src/BitNetSharp.App/BitNetSharp.App.csproj - CoreArtifactName: BitNetSharp.Core-nuget - ToolArtifactName: BitNetSharp.App-dotnet-tool - NuGetFeedUrl: "" - NuGetApiKey: "" + - group: McpServer + - name: BuildConfiguration + value: Release + - name: SolutionFile + value: BitNet-b1.58-Sharp.slnx + - name: CoreProject + value: src/BitNetSharp.Core/BitNetSharp.Core.csproj + - name: AppProject + value: src/BitNetSharp.App/BitNetSharp.App.csproj + - name: CoreArtifactName + value: BitNetSharp.Core-nuget + - name: ToolArtifactName + value: BitNetSharp.App-dotnet-tool + - name: NuGetFeedUrl + value: "https://nuget.pkg.github.com/sharpninja/index.json" stages: - stage: build @@ -70,6 +77,12 @@ stages: - script: dotnet test "$(SolutionFile)" --configuration "$(BuildConfiguration)" --no-build --no-restore --filter "Category!=SlowLane" 
displayName: Test + - pwsh: | + $modelsDirectory = "$(Build.SourcesDirectory)/src/BitNetSharp.Core/Data/Models" + New-Item -ItemType Directory -Force -Path $modelsDirectory | Out-Null + dotnet run --framework net9.0 --project "$(Build.SourcesDirectory)/$(AppProject)" --configuration "$(BuildConfiguration)" --no-build -- export --output="$modelsDirectory/bitnet-b1.58-default.gguf" + displayName: Generate default model weights + - pwsh: | $outputDirectory = "$(Build.ArtifactStagingDirectory)/packages/core" New-Item -ItemType Directory -Force -Path $outputDirectory | Out-Null @@ -155,16 +168,9 @@ stages: foreach ($package in $packages) { - if ([string]::IsNullOrWhiteSpace($env:NUGET_API_KEY)) - { - dotnet nuget push $package.FullName --source $env:NUGET_SOURCE --api-key AzureArtifacts --skip-duplicate - } - else - { - dotnet nuget push $package.FullName --source $env:NUGET_SOURCE --api-key $env:NUGET_API_KEY --skip-duplicate - } + dotnet nuget push $package.FullName --source $env:NUGET_SOURCE --api-key $env:NUGET_API_KEY --skip-duplicate } displayName: Publish packages to NuGet feed env: NUGET_SOURCE: $(NuGetFeedUrl) - NUGET_API_KEY: $(NuGetApiKey) + NUGET_API_KEY: $(GH_TOKEN) diff --git a/nuget.config b/nuget.config new file mode 100644 index 0000000..b2dfa8d --- /dev/null +++ b/nuget.config @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/src/BitNetSharp.Core/BitNetSharp.Core.csproj b/src/BitNetSharp.Core/BitNetSharp.Core.csproj index d27ba3b..6da77b1 100644 --- a/src/BitNetSharp.Core/BitNetSharp.Core.csproj +++ b/src/BitNetSharp.Core/BitNetSharp.Core.csproj @@ -17,4 +17,12 @@ + + + + From 2c8bd6646122323f096bcce0fd0f3165e46d87c0 Mon Sep 17 00:00:00 2001 From: Sharp Ninja Date: Wed, 8 Apr 2026 10:41:55 -0500 Subject: [PATCH 2/2] Add bucket recall heat map for compaction-aware speculative decoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce BucketRecallHeatMap that tracks per-token and per-chain 
attempt/accept counts during speculative decoding, with a 256x256 chain transition adjacency matrix for hot-path detection. The heat map identifies frequently-traversed chain sequences and uses them to rank buckets for compaction — chains on hot-paths receive a preservation bonus while isolated low-recall chains are pruned first. Key components: - BucketRecallHeatMap: O(1) recording via long[] arrays, hot-path detection via greedy walk on transition matrix, compaction ranking - BucketRecallHeatMapSerializer: binary sidecar (BRHM v1, CRC32) - BucketRecallVisualizer: Mermaid xychart-beta and flowchart output - Configurable via BitNetOptions.EnableRecallHeatMap (default: true) - Persisted alongside model checkpoints and GGUF files as sidecar 41 new tests covering all public methods, serialization round-trips, and benchmark integration (heat map enabled vs disabled). Co-Authored-By: Claude Opus 4.6 (1M context) --- src/BitNetSharp.Core/BitNetOptions.cs | 3 +- src/BitNetSharp.Core/BitNetPaperCheckpoint.cs | 40 +- src/BitNetSharp.Core/BitNetPaperGguf.cs | 46 +- src/BitNetSharp.Core/BitNetPaperModel.cs | 20 + .../Bucketing/BucketRecallHeatMap.cs | 438 ++++++++++++++++ .../BucketRecallHeatMapSerializer.cs | 192 +++++++ .../Bucketing/BucketRecallVisualizer.cs | 212 ++++++++ .../Serialization/BitNetPaperModelSnapshot.cs | 9 +- .../BucketRecallHeatMapSerializerTests.cs | 152 ++++++ .../BucketRecallHeatMapTests.cs | 469 ++++++++++++++++++ .../BucketRecallVisualizerTests.cs | 157 ++++++ .../HostedAgentBenchmarksExecutionTests.cs | 49 ++ 12 files changed, 1779 insertions(+), 8 deletions(-) create mode 100644 src/BitNetSharp.Core/Bucketing/BucketRecallHeatMap.cs create mode 100644 src/BitNetSharp.Core/Bucketing/BucketRecallHeatMapSerializer.cs create mode 100644 src/BitNetSharp.Core/Bucketing/BucketRecallVisualizer.cs create mode 100644 tests/BitNetSharp.Tests/BucketRecallHeatMapSerializerTests.cs create mode 100644 tests/BitNetSharp.Tests/BucketRecallHeatMapTests.cs create mode 
100644 tests/BitNetSharp.Tests/BucketRecallVisualizerTests.cs diff --git a/src/BitNetSharp.Core/BitNetOptions.cs b/src/BitNetSharp.Core/BitNetOptions.cs index ae3322f..5008931 100644 --- a/src/BitNetSharp.Core/BitNetOptions.cs +++ b/src/BitNetSharp.Core/BitNetOptions.cs @@ -7,4 +7,5 @@ public sealed record BitNetOptions( string PrimaryLanguage = "en-US", bool EnableChainBuckets = false, bool EnableSequenceCompression = false, - double ChainBucketAcceptanceThreshold = 0.85d); + double ChainBucketAcceptanceThreshold = 0.85d, + bool EnableRecallHeatMap = true); diff --git a/src/BitNetSharp.Core/BitNetPaperCheckpoint.cs b/src/BitNetSharp.Core/BitNetPaperCheckpoint.cs index 783c5bd..002cf89 100644 --- a/src/BitNetSharp.Core/BitNetPaperCheckpoint.cs +++ b/src/BitNetSharp.Core/BitNetPaperCheckpoint.cs @@ -25,12 +25,14 @@ internal sealed record BitNetPaperCheckpointDocument( string PrimaryLanguage, bool EnableChainBuckets, bool EnableSequenceCompression, - double ChainBucketAcceptanceThreshold); + double ChainBucketAcceptanceThreshold, + bool EnableRecallHeatMap = true); public static class BitNetPaperCheckpoint { private const string FormatName = "bitnet-b1.58-sharp.repository-checkpoint.v1"; private const string BucketSidecarFileName = "chain-buckets.bin"; + private const string HeatMapSidecarFileName = "recall-heatmap.bin"; public static void Save(BitNetPaperModel model, string path) { @@ -59,9 +61,11 @@ public static void Save(BitNetPaperModel model, string path) snapshot.PrimaryLanguage, snapshot.EnableChainBuckets, snapshot.EnableSequenceCompression, - snapshot.ChainBucketAcceptanceThreshold); + snapshot.ChainBucketAcceptanceThreshold, + model.Options.EnableRecallHeatMap); File.WriteAllText(path, JsonSerializer.Serialize(document, new JsonSerializerOptions { WriteIndented = true })); SaveBucketSidecar(model.BucketTable, GetBucketSidecarPath(path)); + SaveHeatMapSidecar(model.RecallHeatMap, GetHeatMapSidecarPath(path)); } public static BitNetPaperModel Load(string 
path, VerbosityLevel verbosity = VerbosityLevel.Normal) @@ -86,7 +90,8 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb document.PrimaryLanguage, document.EnableChainBuckets, document.EnableSequenceCompression, - acceptanceThreshold), + acceptanceThreshold, + document.EnableRecallHeatMap), document.Config, document.BootstrapSeed); var baselineSnapshot = BitNetPaperModelSnapshot.Capture(baselineModel); @@ -119,6 +124,12 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb model.LoadBucketTable(LoadBucketSidecar(bucketSidecarPath)); } + var heatMapSidecarPath = GetHeatMapSidecarPath(path); + if (model.RecallHeatMap is not null && File.Exists(heatMapSidecarPath)) + { + model.RecallHeatMap.MergeFrom(BucketRecallHeatMapSerializer.Load(heatMapSidecarPath)); + } + return model; } @@ -175,6 +186,29 @@ private static void SaveBucketSidecar(ChainBucketTable? bucketTable, string path private static ChainBucketTable LoadBucketSidecar(string path) => ChainBucketTableBinarySerializer.Load(path); + private static string GetHeatMapSidecarPath(string checkpointPath) + { + var directory = Path.GetDirectoryName(checkpointPath); + return string.IsNullOrWhiteSpace(directory) + ? HeatMapSidecarFileName + : Path.Combine(directory, HeatMapSidecarFileName); + } + + private static void SaveHeatMapSidecar(BucketRecallHeatMap? 
heatMap, string path) + { + if (heatMap is null) + { + if (File.Exists(path)) + { + File.Delete(path); + } + + return; + } + + BucketRecallHeatMapSerializer.Save(heatMap, path); + } + private static float[][] ToJagged(float[,] matrix) { var result = new float[matrix.GetLength(0)][]; diff --git a/src/BitNetSharp.Core/BitNetPaperGguf.cs b/src/BitNetSharp.Core/BitNetPaperGguf.cs index 916868e..2b0048e 100644 --- a/src/BitNetSharp.Core/BitNetPaperGguf.cs +++ b/src/BitNetSharp.Core/BitNetPaperGguf.cs @@ -25,6 +25,7 @@ public static void Save(BitNetPaperModel model, string path) GgufWriter.Write(path, CreateMetadata(snapshot), CreateTensors(snapshot)); SaveBucketSidecar(model.BucketTable, GetBucketSidecarPath(path)); + SaveHeatMapSidecar(model.RecallHeatMap, GetHeatMapSidecarPath(path)); } public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = VerbosityLevel.Normal) @@ -78,7 +79,8 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb transformerProjectionWeights, normScales, ReadMatrix(tensors[OutputTensorName], config.VocabSize, config.Dimension), - DeserializeMemorizedResponses(GetRequiredString(document.Metadata, MemorizedResponsesMetadataKey))); + DeserializeMemorizedResponses(GetRequiredString(document.Metadata, MemorizedResponsesMetadataKey)), + ReadOptionalBool(document.Metadata, "bitnetsharp.enable_recall_heat_map", defaultValue: true)); var model = snapshot.Restore(verbosity); var bucketSidecarPath = GetBucketSidecarPath(path); @@ -87,6 +89,12 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb model.LoadBucketTable(ChainBucketTableBinarySerializer.Load(bucketSidecarPath)); } + var heatMapSidecarPath = GetHeatMapSidecarPath(path); + if (model.RecallHeatMap is not null && File.Exists(heatMapSidecarPath)) + { + model.RecallHeatMap.MergeFrom(BucketRecallHeatMapSerializer.Load(heatMapSidecarPath)); + } + return model; } @@ -106,6 +114,7 @@ private static Dictionary 
CreateMetadata(BitNetPaperModelSnapsho ["bitnetsharp.primary_language"] = snapshot.PrimaryLanguage, ["bitnetsharp.enable_chain_buckets"] = snapshot.EnableChainBuckets, ["bitnetsharp.enable_sequence_compression"] = snapshot.EnableSequenceCompression, + ["bitnetsharp.enable_recall_heat_map"] = snapshot.EnableRecallHeatMap, ["bitnetsharp.chain_bucket_acceptance_threshold"] = snapshot.ChainBucketAcceptanceThreshold, ["bitnetsharp.config.vocab_size"] = snapshot.Config.VocabSize, ["bitnetsharp.config.dimension"] = snapshot.Config.Dimension, @@ -310,6 +319,31 @@ private static float[] FlattenMatrix(float[,] matrix) private static string GetFeedForwardProjectionTensorName(int layer, string suffix) => $"blk.{layer}.ffn_{suffix}.weight"; + private static string GetHeatMapSidecarPath(string ggufPath) + { + var directory = Path.GetDirectoryName(ggufPath); + var baseName = Path.GetFileNameWithoutExtension(ggufPath); + var fileName = $"{baseName}.recall-heatmap.bin"; + return string.IsNullOrWhiteSpace(directory) + ? fileName + : Path.Combine(directory, fileName); + } + + private static void SaveHeatMapSidecar(BucketRecallHeatMap? 
heatMap, string path) + { + if (heatMap is null) + { + if (File.Exists(path)) + { + File.Delete(path); + } + + return; + } + + BucketRecallHeatMapSerializer.Save(heatMap, path); + } + private static string GetBucketSidecarPath(string ggufPath) { var directory = Path.GetDirectoryName(ggufPath); @@ -354,6 +388,16 @@ private static string GetRequiredString(IReadOnlyDictionary meta return text; } + private static bool ReadOptionalBool(IReadOnlyDictionary metadata, string key, bool defaultValue) + { + if (metadata.TryGetValue(key, out var value) && value is bool boolean) + { + return boolean; + } + + return defaultValue; + } + private static bool GetRequiredBool(IReadOnlyDictionary metadata, string key) { if (!metadata.TryGetValue(key, out var value) || value is not bool boolean) diff --git a/src/BitNetSharp.Core/BitNetPaperModel.cs b/src/BitNetSharp.Core/BitNetPaperModel.cs index 151d19d..4aff55f 100644 --- a/src/BitNetSharp.Core/BitNetPaperModel.cs +++ b/src/BitNetSharp.Core/BitNetPaperModel.cs @@ -24,6 +24,7 @@ public sealed class BitNetPaperModel private readonly string[] _idToToken; private readonly BitNetTokenizer _tokenizer; private readonly object _gate = new(); + private BucketRecallHeatMap? _recallHeatMap; public BitNetPaperModel(IEnumerable trainingExamples, VerbosityLevel verbosity = VerbosityLevel.Normal, BitNetConfig? config = null, int seed = 42) : this( @@ -102,12 +103,21 @@ .. options.Vocabulary /// public ChainBucketTable? BucketTable { get; private set; } + /// + /// Optional recall heat map that tracks per-token and per-chain accept/attempt counts + /// during speculative decoding. Populated when a bucket table is loaded and + /// is set. + /// + public BucketRecallHeatMap? 
RecallHeatMap => _recallHeatMap; + public string ModelId => "bitnet-b1.58-sharp"; public BitNetTokenizer Tokenizer => _tokenizer; public long EstimateResidentParameterBytes() => Transformer.EstimateResidentParameterBytes(); + public string GetTokenString(int tokenId) => _idToToken[tokenId]; + /// /// Mines chain buckets from the provided training examples using the model's tokenizer, /// builds a , attaches it to this model, and returns it. @@ -141,6 +151,11 @@ public void LoadBucketTable(ChainBucketTable table) { ArgumentNullException.ThrowIfNull(table); BucketTable = table; + + if (Options.EnableRecallHeatMap) + { + _recallHeatMap = new BucketRecallHeatMap(Config.VocabSize); + } } public static BitNetPaperModel CreateDefault( @@ -194,6 +209,8 @@ public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = n { lock (_gate) { + _recallHeatMap?.ResetGenerationState(); + var diagnostics = new List(); var contextTokenIds = TokenizeToIds(prompt).ToList(); var generatedTokenIds = new List(); @@ -293,6 +310,7 @@ public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = n if (matchedPrefixLen > 0) { attemptedChains++; + _recallHeatMap?.RecordChainAttempt(chain.ChainId, chain.TokenIds, matchedPrefixLen); var acceptedTokensForChain = 0; for (var ci = matchedPrefixLen; ci < chain.TokenIds.Length && step < maxGeneratedTokens - 1; ci++) { @@ -322,6 +340,7 @@ public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = n step++; acceptedTokensForChain++; acceptedChainTokens++; + _recallHeatMap?.RecordTokenAccepted(chain.ChainId, speculativeId); if (Options.Verbosity == VerbosityLevel.Verbose) { @@ -333,6 +352,7 @@ public BitNetGenerationResult GenerateResponse(string prompt, int? 
maxTokens = n if (acceptedTokensForChain > 0) { acceptedChains++; + _recallHeatMap?.RecordChainAccepted(chain.ChainId); } } } diff --git a/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMap.cs b/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMap.cs new file mode 100644 index 0000000..43ee462 --- /dev/null +++ b/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMap.cs @@ -0,0 +1,438 @@ +namespace BitNetSharp.Core.Bucketing; + +public sealed record HeatMapCounters( + int VocabSize, + long[] AttemptCounts, + long[] AcceptCounts, + long[] ChainAttemptCounts, + long[] ChainAcceptCounts, + long[,] ChainTransitions); + +public sealed record TokenRecallEntry(int TokenId, long AttemptCount, long AcceptCount, double RecallRate); + +public sealed record BucketRecallRanking( + byte ChainId, + double AggregateRecallRate, + long TotalAcceptCount, + long TotalAttemptCount, + bool OnHotPath); + +public sealed record ChainTransitionEntry(byte ChainId, long TransitionCount); + +public sealed record ChainHotPath(byte[] ChainSequence, long MinTransitionCount); + +public sealed class BucketRecallHeatMap +{ + private readonly long[] _attemptCounts; + private readonly long[] _acceptCounts; + private readonly long[] _chainAttemptCounts; + private readonly long[] _chainAcceptCounts; + private readonly long[,] _chainTransitions; + private byte? 
_lastAcceptedChainId; + + public int VocabSize { get; } + + public BucketRecallHeatMap(int vocabSize) + { + ArgumentOutOfRangeException.ThrowIfNegativeOrZero(vocabSize); + + VocabSize = vocabSize; + _attemptCounts = new long[vocabSize]; + _acceptCounts = new long[vocabSize]; + _chainAttemptCounts = new long[ChainBucketTable.MaxBuckets]; + _chainAcceptCounts = new long[ChainBucketTable.MaxBuckets]; + _chainTransitions = new long[ChainBucketTable.MaxBuckets, ChainBucketTable.MaxBuckets]; + } + + public void RecordChainAttempt(byte chainId, int[] tokenIds, int speculativeStartIndex) + { + ArgumentNullException.ThrowIfNull(tokenIds); + + _chainAttemptCounts[chainId]++; + for (var i = speculativeStartIndex; i < tokenIds.Length; i++) + { + var tokenId = tokenIds[i]; + if ((uint)tokenId < (uint)VocabSize) + { + _attemptCounts[tokenId]++; + } + } + } + + public void RecordTokenAccepted(byte chainId, int tokenId) + { + if ((uint)tokenId < (uint)VocabSize) + { + _acceptCounts[tokenId]++; + } + } + + public void RecordChainAccepted(byte chainId) + { + _chainAcceptCounts[chainId]++; + if (_lastAcceptedChainId.HasValue) + { + _chainTransitions[_lastAcceptedChainId.Value, chainId]++; + } + + _lastAcceptedChainId = chainId; + } + + public void ResetGenerationState() + { + _lastAcceptedChainId = null; + } + + public long GetAttemptCount(int tokenId) => + (uint)tokenId < (uint)VocabSize ? _attemptCounts[tokenId] : 0L; + + public long GetAcceptCount(int tokenId) => + (uint)tokenId < (uint)VocabSize ? _acceptCounts[tokenId] : 0L; + + public double GetTokenRecallRate(int tokenId) + { + var attempts = GetAttemptCount(tokenId); + return attempts == 0L ? 
0d : (double)GetAcceptCount(tokenId) / attempts; + } + + public long GetChainAttemptCount(byte chainId) => _chainAttemptCounts[chainId]; + + public long GetChainAcceptCount(byte chainId) => _chainAcceptCounts[chainId]; + + public double GetChainRecallRate(byte chainId) + { + var attempts = _chainAttemptCounts[chainId]; + return attempts == 0L ? 0d : (double)_chainAcceptCounts[chainId] / attempts; + } + + public long GetTransitionCount(byte fromChainId, byte toChainId) => + _chainTransitions[fromChainId, toChainId]; + + public IReadOnlyList GetIncomingChains(byte chainId, int maxResults = 10) + { + var entries = new List(); + for (var from = 0; from < ChainBucketTable.MaxBuckets; from++) + { + var count = _chainTransitions[from, chainId]; + if (count > 0) + { + entries.Add(new ChainTransitionEntry((byte)from, count)); + } + } + + return entries + .OrderByDescending(static e => e.TransitionCount) + .Take(maxResults) + .ToArray(); + } + + public IReadOnlyList GetOutgoingChains(byte chainId, int maxResults = 10) + { + var entries = new List(); + for (var to = 0; to < ChainBucketTable.MaxBuckets; to++) + { + var count = _chainTransitions[chainId, to]; + if (count > 0) + { + entries.Add(new ChainTransitionEntry((byte)to, count)); + } + } + + return entries + .OrderByDescending(static e => e.TransitionCount) + .Take(maxResults) + .ToArray(); + } + + public IReadOnlyList GetHotPaths(int maxDepth = 5, int maxResults = 10, long minTransitions = 2) + { + var paths = new List(); + var visited = new HashSet(); + + for (var start = 0; start < ChainBucketTable.MaxBuckets; start++) + { + if (_chainAcceptCounts[start] == 0) + { + continue; + } + + visited.Clear(); + var sequence = new List { (byte)start }; + visited.Add((byte)start); + var bottleneck = long.MaxValue; + + var current = (byte)start; + for (var depth = 1; depth < maxDepth; depth++) + { + var bestNext = -1; + var bestCount = 0L; + for (var next = 0; next < ChainBucketTable.MaxBuckets; next++) + { + var count = 
_chainTransitions[current, next]; + if (count >= minTransitions && count > bestCount && !visited.Contains((byte)next)) + { + bestNext = next; + bestCount = count; + } + } + + if (bestNext < 0) + { + break; + } + + sequence.Add((byte)bestNext); + visited.Add((byte)bestNext); + bottleneck = Math.Min(bottleneck, bestCount); + current = (byte)bestNext; + } + + if (sequence.Count >= 2) + { + paths.Add(new ChainHotPath(sequence.ToArray(), bottleneck)); + } + } + + // Remove sub-paths: if path A is a prefix of path B, keep only B. + var deduplicated = new List(); + var sortedByLength = paths.OrderByDescending(static p => p.ChainSequence.Length).ToArray(); + foreach (var path in sortedByLength) + { + var isSubPath = false; + foreach (var longer in deduplicated) + { + if (IsSubsequence(path.ChainSequence, longer.ChainSequence)) + { + isSubPath = true; + break; + } + } + + if (!isSubPath) + { + deduplicated.Add(path); + } + } + + return deduplicated + .OrderByDescending(static p => p.MinTransitionCount) + .Take(maxResults) + .ToArray(); + } + + public IReadOnlyList GetTopTokensByAcceptCount(int maxResults = 20) + { + var entries = new List(); + for (var i = 0; i < VocabSize; i++) + { + if (_acceptCounts[i] > 0) + { + entries.Add(new TokenRecallEntry(i, _attemptCounts[i], _acceptCounts[i], GetTokenRecallRate(i))); + } + } + + return entries + .OrderByDescending(static e => e.AcceptCount) + .Take(maxResults) + .ToArray(); + } + + public IReadOnlyList GetTopTokensByRecallRate(int maxResults = 20, int minAttempts = 1) + { + var entries = new List(); + for (var i = 0; i < VocabSize; i++) + { + if (_attemptCounts[i] >= minAttempts) + { + entries.Add(new TokenRecallEntry(i, _attemptCounts[i], _acceptCounts[i], GetTokenRecallRate(i))); + } + } + + return entries + .OrderByDescending(static e => e.RecallRate) + .Take(maxResults) + .ToArray(); + } + + public IReadOnlyList RankBucketsForCompaction(ChainBucketTable table) + { + ArgumentNullException.ThrowIfNull(table); + + var 
hotPathChainIds = GetHotPathChainIds(); + var rankings = new List(); + + foreach (var bucket in table.Buckets) + { + var attempts = _chainAttemptCounts[bucket.ChainId]; + var accepts = _chainAcceptCounts[bucket.ChainId]; + var recallRate = attempts == 0L ? 0d : (double)accepts / attempts; + var onHotPath = hotPathChainIds.Contains(bucket.ChainId); + + rankings.Add(new BucketRecallRanking( + bucket.ChainId, + recallRate, + accepts, + attempts, + onHotPath)); + } + + // Sort worst-first: non-hot-path chains with low recall come first (pruning candidates). + return rankings + .OrderBy(static r => r.OnHotPath) + .ThenBy(static r => r.AggregateRecallRate) + .ThenBy(static r => r.TotalAcceptCount) + .ToArray(); + } + + public IReadOnlySet IdentifyLowValueBuckets(ChainBucketTable table, double threshold = 0.5, int minAttempts = 2) + { + ArgumentNullException.ThrowIfNull(table); + + var hotPathChainIds = GetHotPathChainIds(); + var lowValue = new HashSet(); + + foreach (var bucket in table.Buckets) + { + if (hotPathChainIds.Contains(bucket.ChainId)) + { + continue; + } + + var attempts = _chainAttemptCounts[bucket.ChainId]; + if (attempts < minAttempts) + { + lowValue.Add(bucket.ChainId); + continue; + } + + var recallRate = (double)_chainAcceptCounts[bucket.ChainId] / attempts; + if (recallRate < threshold) + { + lowValue.Add(bucket.ChainId); + } + } + + return lowValue; + } + + public void Reset() + { + Array.Clear(_attemptCounts); + Array.Clear(_acceptCounts); + Array.Clear(_chainAttemptCounts); + Array.Clear(_chainAcceptCounts); + Array.Clear(_chainTransitions); + _lastAcceptedChainId = null; + } + + public void MergeFrom(BucketRecallHeatMap other) + { + ArgumentNullException.ThrowIfNull(other); + if (other.VocabSize != VocabSize) + { + throw new ArgumentException( + $"Cannot merge heat maps with different vocab sizes ({VocabSize} vs {other.VocabSize}).", + nameof(other)); + } + + for (var i = 0; i < VocabSize; i++) + { + _attemptCounts[i] += other._attemptCounts[i]; + 
_acceptCounts[i] += other._acceptCounts[i]; + } + + for (var i = 0; i < ChainBucketTable.MaxBuckets; i++) + { + _chainAttemptCounts[i] += other._chainAttemptCounts[i]; + _chainAcceptCounts[i] += other._chainAcceptCounts[i]; + for (var j = 0; j < ChainBucketTable.MaxBuckets; j++) + { + _chainTransitions[i, j] += other._chainTransitions[i, j]; + } + } + } + + public HeatMapCounters ExportCounters() + { + var attemptCounts = (long[])_attemptCounts.Clone(); + var acceptCounts = (long[])_acceptCounts.Clone(); + var chainAttemptCounts = (long[])_chainAttemptCounts.Clone(); + var chainAcceptCounts = (long[])_chainAcceptCounts.Clone(); + var chainTransitions = (long[,])_chainTransitions.Clone(); + return new HeatMapCounters(VocabSize, attemptCounts, acceptCounts, chainAttemptCounts, chainAcceptCounts, chainTransitions); + } + + public static BucketRecallHeatMap FromCounters(HeatMapCounters counters) + { + ArgumentNullException.ThrowIfNull(counters); + + if (counters.AttemptCounts.Length != counters.VocabSize + || counters.AcceptCounts.Length != counters.VocabSize) + { + throw new ArgumentException("Token counter arrays must match VocabSize.", nameof(counters)); + } + + if (counters.ChainAttemptCounts.Length != ChainBucketTable.MaxBuckets + || counters.ChainAcceptCounts.Length != ChainBucketTable.MaxBuckets) + { + throw new ArgumentException($"Chain counter arrays must have length {ChainBucketTable.MaxBuckets}.", nameof(counters)); + } + + if (counters.ChainTransitions.GetLength(0) != ChainBucketTable.MaxBuckets + || counters.ChainTransitions.GetLength(1) != ChainBucketTable.MaxBuckets) + { + throw new ArgumentException($"Chain transitions matrix must be {ChainBucketTable.MaxBuckets}x{ChainBucketTable.MaxBuckets}.", nameof(counters)); + } + + var heatMap = new BucketRecallHeatMap(counters.VocabSize); + Array.Copy(counters.AttemptCounts, heatMap._attemptCounts, counters.VocabSize); + Array.Copy(counters.AcceptCounts, heatMap._acceptCounts, counters.VocabSize); + 
Array.Copy(counters.ChainAttemptCounts, heatMap._chainAttemptCounts, ChainBucketTable.MaxBuckets); + Array.Copy(counters.ChainAcceptCounts, heatMap._chainAcceptCounts, ChainBucketTable.MaxBuckets); + Array.Copy(counters.ChainTransitions, heatMap._chainTransitions, ChainBucketTable.MaxBuckets * ChainBucketTable.MaxBuckets); + return heatMap; + } + + private HashSet GetHotPathChainIds() + { + var hotPaths = GetHotPaths(); + var chainIds = new HashSet(); + foreach (var path in hotPaths) + { + foreach (var chainId in path.ChainSequence) + { + chainIds.Add(chainId); + } + } + + return chainIds; + } + + private static bool IsSubsequence(byte[] candidate, byte[] longer) + { + if (candidate.Length >= longer.Length) + { + return false; + } + + for (var start = 0; start <= longer.Length - candidate.Length; start++) + { + var match = true; + for (var i = 0; i < candidate.Length; i++) + { + if (longer[start + i] != candidate[i]) + { + match = false; + break; + } + } + + if (match) + { + return true; + } + } + + return false; + } +} diff --git a/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMapSerializer.cs b/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMapSerializer.cs new file mode 100644 index 0000000..9d97468 --- /dev/null +++ b/src/BitNetSharp.Core/Bucketing/BucketRecallHeatMapSerializer.cs @@ -0,0 +1,192 @@ +using System.Buffers.Binary; +using System.Text; + +namespace BitNetSharp.Core.Bucketing; + +/// +/// Provides a compact binary serializer for sidecar persistence. +/// The wire format is little-endian and follows the repository's recall-heatmap.bin v1 layout. +/// +public static class BucketRecallHeatMapSerializer +{ + private static readonly byte[] MagicHeader = Encoding.ASCII.GetBytes("BRHM"); + private const ushort FormatVersion = 1; + private const int HeaderLength = 12; + private const int FooterLength = 4; + + /// Saves a recall heat map to a binary file. 
+ public static void Save(BucketRecallHeatMap heatMap, string path) + { + ArgumentNullException.ThrowIfNull(heatMap); + ArgumentException.ThrowIfNullOrWhiteSpace(path); + + var directory = Path.GetDirectoryName(path); + if (!string.IsNullOrWhiteSpace(directory)) + { + Directory.CreateDirectory(directory); + } + + using var stream = File.Create(path); + Serialize(heatMap, stream); + } + + /// Loads a recall heat map from a binary file. + public static BucketRecallHeatMap Load(string path) + { + ArgumentException.ThrowIfNullOrWhiteSpace(path); + + using var stream = File.OpenRead(path); + return Deserialize(stream); + } + + /// Writes a recall heat map to a stream using the binary sidecar format. + public static void Serialize(BucketRecallHeatMap heatMap, Stream destination) + { + ArgumentNullException.ThrowIfNull(heatMap); + ArgumentNullException.ThrowIfNull(destination); + + var counters = heatMap.ExportCounters(); + + using var payload = new MemoryStream(); + using (var writer = new BinaryWriter(payload, Encoding.ASCII, leaveOpen: true)) + { + writer.Write(MagicHeader); + writer.Write(FormatVersion); + writer.Write((ushort)0); + writer.Write(counters.VocabSize); + + for (var i = 0; i < counters.VocabSize; i++) + { + writer.Write(counters.AttemptCounts[i]); + } + + for (var i = 0; i < counters.VocabSize; i++) + { + writer.Write(counters.AcceptCounts[i]); + } + + for (var i = 0; i < ChainBucketTable.MaxBuckets; i++) + { + writer.Write(counters.ChainAttemptCounts[i]); + } + + for (var i = 0; i < ChainBucketTable.MaxBuckets; i++) + { + writer.Write(counters.ChainAcceptCounts[i]); + } + + for (var row = 0; row < ChainBucketTable.MaxBuckets; row++) + { + for (var col = 0; col < ChainBucketTable.MaxBuckets; col++) + { + writer.Write(counters.ChainTransitions[row, col]); + } + } + } + + var buffer = payload.ToArray(); + var checksum = ComputeCrc32(buffer); + destination.Write(buffer, 0, buffer.Length); + + Span checksumBuffer = stackalloc byte[FooterLength]; + 
BinaryPrimitives.WriteUInt32LittleEndian(checksumBuffer, checksum); + destination.Write(checksumBuffer); + } + + /// Reads a recall heat map from a stream using the binary sidecar format. + public static BucketRecallHeatMap Deserialize(Stream source) + { + ArgumentNullException.ThrowIfNull(source); + + using var payload = new MemoryStream(); + source.CopyTo(payload); + var bytes = payload.ToArray(); + + if (bytes.Length < HeaderLength + FooterLength) + { + throw new InvalidDataException("Recall heat map binary payload is too short."); + } + + var expectedChecksum = BinaryPrimitives.ReadUInt32LittleEndian(bytes.AsSpan(bytes.Length - FooterLength, FooterLength)); + var actualChecksum = ComputeCrc32(bytes.AsSpan(0, bytes.Length - FooterLength)); + if (actualChecksum != expectedChecksum) + { + throw new InvalidDataException("Recall heat map CRC32 checksum mismatch."); + } + + using var reader = new BinaryReader(new MemoryStream(bytes, writable: false), Encoding.ASCII, leaveOpen: false); + var magic = reader.ReadBytes(MagicHeader.Length); + if (!magic.AsSpan().SequenceEqual(MagicHeader)) + { + throw new InvalidDataException("Unsupported recall heat map binary header."); + } + + var version = reader.ReadUInt16(); + _ = reader.ReadUInt16(); + var vocabSize = reader.ReadInt32(); + + if (version != FormatVersion) + { + throw new InvalidDataException($"Unsupported recall heat map version {version}."); + } + + if (vocabSize <= 0) + { + throw new InvalidDataException($"Invalid vocab size {vocabSize} in recall heat map."); + } + + var attemptCounts = new long[vocabSize]; + for (var i = 0; i < vocabSize; i++) + { + attemptCounts[i] = reader.ReadInt64(); + } + + var acceptCounts = new long[vocabSize]; + for (var i = 0; i < vocabSize; i++) + { + acceptCounts[i] = reader.ReadInt64(); + } + + var chainAttemptCounts = new long[ChainBucketTable.MaxBuckets]; + for (var i = 0; i < ChainBucketTable.MaxBuckets; i++) + { + chainAttemptCounts[i] = reader.ReadInt64(); + } + + var 
chainAcceptCounts = new long[ChainBucketTable.MaxBuckets]; + for (var i = 0; i < ChainBucketTable.MaxBuckets; i++) + { + chainAcceptCounts[i] = reader.ReadInt64(); + } + + var chainTransitions = new long[ChainBucketTable.MaxBuckets, ChainBucketTable.MaxBuckets]; + for (var row = 0; row < ChainBucketTable.MaxBuckets; row++) + { + for (var col = 0; col < ChainBucketTable.MaxBuckets; col++) + { + chainTransitions[row, col] = reader.ReadInt64(); + } + } + + var counters = new HeatMapCounters(vocabSize, attemptCounts, acceptCounts, chainAttemptCounts, chainAcceptCounts, chainTransitions); + return BucketRecallHeatMap.FromCounters(counters); + } + + private static uint ComputeCrc32(ReadOnlySpan data) + { + const uint polynomial = 0xEDB88320u; + var crc = 0xFFFFFFFFu; + + foreach (var value in data) + { + crc ^= value; + for (var bit = 0; bit < 8; bit++) + { + var mask = (crc & 1u) == 0u ? 0u : polynomial; + crc = (crc >> 1) ^ mask; + } + } + + return ~crc; + } +} diff --git a/src/BitNetSharp.Core/Bucketing/BucketRecallVisualizer.cs b/src/BitNetSharp.Core/Bucketing/BucketRecallVisualizer.cs new file mode 100644 index 0000000..98ad684 --- /dev/null +++ b/src/BitNetSharp.Core/Bucketing/BucketRecallVisualizer.cs @@ -0,0 +1,212 @@ +using System.Text; + +namespace BitNetSharp.Core.Bucketing; + +/// +/// Renders bucket recall heat map data as Mermaid diagrams for visualization +/// in markdown, GitHub, or any Mermaid-compatible renderer. +/// +public static class BucketRecallVisualizer +{ + /// + /// Renders a Mermaid xychart-beta bar chart of the top tokens by accept count. 
+ /// + public static string RenderTokenHeatMap( + BucketRecallHeatMap heatMap, + Func tokenResolver, + int maxTokens = 15) + { + ArgumentNullException.ThrowIfNull(heatMap); + ArgumentNullException.ThrowIfNull(tokenResolver); + + var topTokens = heatMap.GetTopTokensByAcceptCount(maxTokens); + var sb = new StringBuilder(); + sb.AppendLine("```mermaid"); + sb.AppendLine("xychart-beta"); + sb.AppendLine(" title \"Token Recall Heat Map (by accept count)\""); + + if (topTokens.Count == 0) + { + sb.AppendLine(" x-axis [\"(no data)\"]"); + sb.AppendLine(" bar [0]"); + } + else + { + var labels = string.Join(", ", topTokens.Select(t => $"\"{Escape(tokenResolver(t.TokenId))}\"")); + var values = string.Join(", ", topTokens.Select(static t => t.AcceptCount)); + sb.AppendLine($" x-axis [{labels}]"); + sb.AppendLine($" bar [{values}]"); + } + + sb.AppendLine("```"); + return sb.ToString(); + } + + /// + /// Renders a Mermaid xychart-beta bar chart of chain recall rates. + /// + public static string RenderChainRecallChart( + BucketRecallHeatMap heatMap, + ChainBucketTable table, + Func tokenResolver, + int maxChains = 20) + { + ArgumentNullException.ThrowIfNull(heatMap); + ArgumentNullException.ThrowIfNull(table); + ArgumentNullException.ThrowIfNull(tokenResolver); + + var chains = table.Buckets + .Where(b => heatMap.GetChainAttemptCount(b.ChainId) > 0) + .OrderByDescending(b => heatMap.GetChainRecallRate(b.ChainId)) + .Take(maxChains) + .ToArray(); + + var sb = new StringBuilder(); + sb.AppendLine("```mermaid"); + sb.AppendLine("xychart-beta"); + sb.AppendLine(" title \"Chain Recall Rate\""); + + if (chains.Length == 0) + { + sb.AppendLine(" x-axis [\"(no data)\"]"); + sb.AppendLine(" bar [0]"); + } + else + { + var labels = string.Join(", ", chains.Select(c => + $"\"{Escape(FormatChainLabel(c, tokenResolver))}\"")); + var values = string.Join(", ", chains.Select(c => + (heatMap.GetChainRecallRate(c.ChainId) * 100).ToString("F0"))); + sb.AppendLine($" x-axis [{labels}]"); + 
sb.AppendLine(" y-axis \"Recall %\""); + sb.AppendLine($" bar [{values}]"); + } + + sb.AppendLine("```"); + return sb.ToString(); + } + + /// + /// Renders a Mermaid flowchart LR showing hot-path chain sequences with transition counts on edges. + /// + public static string RenderHotPathDiagram( + BucketRecallHeatMap heatMap, + ChainBucketTable table, + Func tokenResolver, + int maxDepth = 5, + int maxPaths = 5) + { + ArgumentNullException.ThrowIfNull(heatMap); + ArgumentNullException.ThrowIfNull(table); + ArgumentNullException.ThrowIfNull(tokenResolver); + + var hotPaths = heatMap.GetHotPaths(maxDepth, maxPaths); + + var sb = new StringBuilder(); + sb.AppendLine("```mermaid"); + sb.AppendLine("flowchart LR"); + + if (hotPaths.Count == 0) + { + sb.AppendLine(" empty[\"No hot-paths detected\"]"); + } + else + { + var declaredNodes = new HashSet(); + foreach (var path in hotPaths) + { + foreach (var chainId in path.ChainSequence) + { + if (declaredNodes.Add(chainId)) + { + var bucket = table.GetById(chainId); + var label = bucket is not null + ? FormatChainLabel(bucket, tokenResolver) + : $"chain {chainId}"; + var accepts = heatMap.GetChainAcceptCount(chainId); + sb.AppendLine($" C{chainId}[\"{Escape(label)}
accepts: {accepts}\"]"); + } + } + + for (var i = 0; i < path.ChainSequence.Length - 1; i++) + { + var from = path.ChainSequence[i]; + var to = path.ChainSequence[i + 1]; + var count = heatMap.GetTransitionCount(from, to); + sb.AppendLine($" C{from} -->|{count}| C{to}"); + } + } + } + + sb.AppendLine("```"); + return sb.ToString(); + } + + /// + /// Renders a Mermaid flowchart showing all chains with color coding: + /// green for hot-path chains, red for low-value compaction candidates. + /// + public static string RenderCompactionReport( + BucketRecallHeatMap heatMap, + ChainBucketTable table, + Func tokenResolver, + double threshold = 0.5) + { + ArgumentNullException.ThrowIfNull(heatMap); + ArgumentNullException.ThrowIfNull(table); + ArgumentNullException.ThrowIfNull(tokenResolver); + + var rankings = heatMap.RankBucketsForCompaction(table); + var lowValue = heatMap.IdentifyLowValueBuckets(table, threshold); + + var sb = new StringBuilder(); + sb.AppendLine("```mermaid"); + sb.AppendLine("flowchart TD"); + + var greenNodes = new List(); + var redNodes = new List(); + + foreach (var ranking in rankings) + { + var bucket = table.GetById(ranking.ChainId); + var label = bucket is not null + ? FormatChainLabel(bucket, tokenResolver) + : $"chain {ranking.ChainId}"; + var rate = ranking.AggregateRecallRate * 100; + var nodeId = $"C{ranking.ChainId}"; + + sb.AppendLine($" {nodeId}[\"{Escape(label)}
recall: {rate:F0}% accepts: {ranking.TotalAcceptCount}\"]"); + + if (ranking.OnHotPath) + { + greenNodes.Add(nodeId); + } + else if (lowValue.Contains(ranking.ChainId)) + { + redNodes.Add(nodeId); + } + } + + if (greenNodes.Count > 0) + { + sb.AppendLine($" style {string.Join(",", greenNodes)} fill:#4c4,stroke:#393,color:#fff"); + } + + if (redNodes.Count > 0) + { + sb.AppendLine($" style {string.Join(",", redNodes)} fill:#f44,stroke:#c33,color:#fff"); + } + + sb.AppendLine("```"); + return sb.ToString(); + } + + private static string FormatChainLabel(ChainBucket bucket, Func tokenResolver) + { + var tokens = bucket.TokenIds.Select(tokenResolver); + return string.Join(" ", tokens); + } + + private static string Escape(string text) => + text.Replace("\"", "#quot;", StringComparison.Ordinal); +} diff --git a/src/BitNetSharp.Core/Serialization/BitNetPaperModelSnapshot.cs b/src/BitNetSharp.Core/Serialization/BitNetPaperModelSnapshot.cs index ebf33bf..56dd6dd 100644 --- a/src/BitNetSharp.Core/Serialization/BitNetPaperModelSnapshot.cs +++ b/src/BitNetSharp.Core/Serialization/BitNetPaperModelSnapshot.cs @@ -16,7 +16,8 @@ internal sealed record BitNetPaperModelSnapshot( IReadOnlyList TransformerProjectionWeights, IReadOnlyList NormScales, float[,] OutputHeadWeights, - IReadOnlyDictionary MemorizedResponses) + IReadOnlyDictionary MemorizedResponses, + bool EnableRecallHeatMap = true) { internal const int DefaultBootstrapSeed = 42; @@ -38,7 +39,8 @@ public static BitNetPaperModelSnapshot Capture(BitNetPaperModel model) model.ExportTransformerProjectionWeights().Select(CloneMatrix).ToArray(), model.ExportNormScales().Select(CloneVector).ToArray(), CloneMatrix(model.ExportOutputHeadWeights()), - CloneMemorizedResponses(model.ExportMemorizedResponses())); + CloneMemorizedResponses(model.ExportMemorizedResponses()), + model.Options.EnableRecallHeatMap); } public BitNetPaperModel Restore(VerbosityLevel verbosity = VerbosityLevel.Normal) @@ -51,7 +53,8 @@ public BitNetPaperModel 
Restore(VerbosityLevel verbosity = VerbosityLevel.Normal PrimaryLanguage, EnableChainBuckets, EnableSequenceCompression, - ChainBucketAcceptanceThreshold), + ChainBucketAcceptanceThreshold, + EnableRecallHeatMap), Config, BootstrapSeed); diff --git a/tests/BitNetSharp.Tests/BucketRecallHeatMapSerializerTests.cs b/tests/BitNetSharp.Tests/BucketRecallHeatMapSerializerTests.cs new file mode 100644 index 0000000..ae556f2 --- /dev/null +++ b/tests/BitNetSharp.Tests/BucketRecallHeatMapSerializerTests.cs @@ -0,0 +1,152 @@ +using BitNetSharp.Core.Bucketing; + +namespace BitNetSharp.Tests; + +public sealed class BucketRecallHeatMapSerializerTests +{ + [Fact] + public void RoundTrip_PreservesAllCountersAndTransitions() + { + var original = new BucketRecallHeatMap(64); + + original.RecordChainAttempt(0, [5, 10, 15], speculativeStartIndex: 1); + original.RecordChainAttempt(0, [5, 10, 15], speculativeStartIndex: 1); + original.RecordTokenAccepted(0, 10); + original.RecordChainAccepted(0); + original.RecordChainAccepted(3); + original.ResetGenerationState(); + original.RecordChainAccepted(0); + original.RecordChainAccepted(3); + + using var stream = new MemoryStream(); + BucketRecallHeatMapSerializer.Serialize(original, stream); + stream.Position = 0; + var restored = BucketRecallHeatMapSerializer.Deserialize(stream); + + Assert.Equal(original.VocabSize, restored.VocabSize); + Assert.Equal(original.GetAttemptCount(10), restored.GetAttemptCount(10)); + Assert.Equal(original.GetAttemptCount(15), restored.GetAttemptCount(15)); + Assert.Equal(original.GetAcceptCount(10), restored.GetAcceptCount(10)); + Assert.Equal(original.GetChainAttemptCount(0), restored.GetChainAttemptCount(0)); + Assert.Equal(original.GetChainAcceptCount(0), restored.GetChainAcceptCount(0)); + Assert.Equal(original.GetChainAcceptCount(3), restored.GetChainAcceptCount(3)); + Assert.Equal(original.GetTransitionCount(0, 3), restored.GetTransitionCount(0, 3)); + Assert.Equal(2, restored.GetTransitionCount(0, 3)); + 
} + + [Fact] + public void Deserialize_RejectsCorruptedChecksum() + { + var original = new BucketRecallHeatMap(16); + original.RecordChainAttempt(0, [1, 2], speculativeStartIndex: 0); + + using var stream = new MemoryStream(); + BucketRecallHeatMapSerializer.Serialize(original, stream); + var bytes = stream.ToArray(); + bytes[20] ^= 0x01; + + using var corrupted = new MemoryStream(bytes); + var exception = Assert.Throws(() => BucketRecallHeatMapSerializer.Deserialize(corrupted)); + Assert.Contains("CRC32", exception.Message, StringComparison.Ordinal); + } + + [Fact] + public void Deserialize_RejectsInvalidMagic() + { + var original = new BucketRecallHeatMap(16); + + using var stream = new MemoryStream(); + BucketRecallHeatMapSerializer.Serialize(original, stream); + var bytes = stream.ToArray(); + bytes[0] = (byte)'X'; + + // Recompute CRC32 so the checksum doesn't fail first. + var crc = ComputeCrc32(bytes.AsSpan(0, bytes.Length - 4)); + BitConverter.TryWriteBytes(bytes.AsSpan(bytes.Length - 4), crc); + + using var corrupted = new MemoryStream(bytes); + var exception = Assert.Throws(() => BucketRecallHeatMapSerializer.Deserialize(corrupted)); + Assert.Contains("header", exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public void Deserialize_RejectsVersionMismatch() + { + var original = new BucketRecallHeatMap(16); + + using var stream = new MemoryStream(); + BucketRecallHeatMapSerializer.Serialize(original, stream); + var bytes = stream.ToArray(); + // Version is at offset 4 (2 bytes LE). Set to 99. 
+ bytes[4] = 99; + bytes[5] = 0; + + var crc = ComputeCrc32(bytes.AsSpan(0, bytes.Length - 4)); + BitConverter.TryWriteBytes(bytes.AsSpan(bytes.Length - 4), crc); + + using var corrupted = new MemoryStream(bytes); + var exception = Assert.Throws(() => BucketRecallHeatMapSerializer.Deserialize(corrupted)); + Assert.Contains("version", exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public void RoundTripWithEmptyHeatMap() + { + var original = new BucketRecallHeatMap(32); + + using var stream = new MemoryStream(); + BucketRecallHeatMapSerializer.Serialize(original, stream); + stream.Position = 0; + var restored = BucketRecallHeatMapSerializer.Deserialize(stream); + + Assert.Equal(32, restored.VocabSize); + Assert.Equal(0, restored.GetAttemptCount(0)); + Assert.Equal(0, restored.GetChainAcceptCount(0)); + Assert.Equal(0, restored.GetTransitionCount(0, 0)); + } + + [Fact] + public void FileRoundTrip_PreservesData() + { + var original = new BucketRecallHeatMap(32); + original.RecordChainAttempt(5, [1, 2, 3], 1); + original.RecordTokenAccepted(5, 2); + original.RecordChainAccepted(5); + + var tempPath = Path.Combine(Path.GetTempPath(), $"heatmap-test-{Guid.NewGuid():N}.bin"); + try + { + BucketRecallHeatMapSerializer.Save(original, tempPath); + var restored = BucketRecallHeatMapSerializer.Load(tempPath); + + Assert.Equal(32, restored.VocabSize); + Assert.Equal(1, restored.GetAttemptCount(2)); + Assert.Equal(1, restored.GetAcceptCount(2)); + Assert.Equal(1, restored.GetChainAcceptCount(5)); + } + finally + { + if (File.Exists(tempPath)) + { + File.Delete(tempPath); + } + } + } + + private static uint ComputeCrc32(ReadOnlySpan data) + { + const uint polynomial = 0xEDB88320u; + var crc = 0xFFFFFFFFu; + foreach (var value in data) + { + crc ^= value; + for (var bit = 0; bit < 8; bit++) + { + var mask = (crc & 1u) == 0u ? 
0u : polynomial; + crc = (crc >> 1) ^ mask; + } + } + + return ~crc; + } +} diff --git a/tests/BitNetSharp.Tests/BucketRecallHeatMapTests.cs b/tests/BitNetSharp.Tests/BucketRecallHeatMapTests.cs new file mode 100644 index 0000000..25d6746 --- /dev/null +++ b/tests/BitNetSharp.Tests/BucketRecallHeatMapTests.cs @@ -0,0 +1,469 @@ +using BitNetSharp.Core.Bucketing; + +namespace BitNetSharp.Tests; + +public sealed class BucketRecallHeatMapTests +{ + private const int TestVocabSize = 32; + + [Fact] + public void RecordChainAttempt_IncrementsAttemptCountsForSpeculativeTokens() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + var tokenIds = new[] { 5, 10, 15, 20 }; + + heatMap.RecordChainAttempt(0, tokenIds, speculativeStartIndex: 2); + + Assert.Equal(0, heatMap.GetAttemptCount(5)); + Assert.Equal(0, heatMap.GetAttemptCount(10)); + Assert.Equal(1, heatMap.GetAttemptCount(15)); + Assert.Equal(1, heatMap.GetAttemptCount(20)); + } + + [Fact] + public void RecordChainAttempt_IncrementsChainAttemptCount() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + heatMap.RecordChainAttempt(3, [1, 2, 3], speculativeStartIndex: 1); + heatMap.RecordChainAttempt(3, [1, 2, 3], speculativeStartIndex: 1); + + Assert.Equal(2, heatMap.GetChainAttemptCount(3)); + } + + [Fact] + public void RecordTokenAccepted_IncrementsAcceptCounters() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + heatMap.RecordTokenAccepted(0, 7); + heatMap.RecordTokenAccepted(0, 7); + heatMap.RecordTokenAccepted(0, 7); + + Assert.Equal(3, heatMap.GetAcceptCount(7)); + } + + [Fact] + public void RecordChainAccepted_RecordsTransitionFromPreviousChain() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + heatMap.RecordChainAccepted(1); + heatMap.RecordChainAccepted(5); + + Assert.Equal(1, heatMap.GetTransitionCount(1, 5)); + } + + [Fact] + public void RecordChainAccepted_NoTransitionOnFirstChainInGeneration() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + 
heatMap.RecordChainAccepted(3); + + Assert.Equal(1, heatMap.GetChainAcceptCount(3)); + // No transition should exist since there was no previous chain. + for (var from = 0; from < ChainBucketTable.MaxBuckets; from++) + { + Assert.Equal(0, heatMap.GetTransitionCount((byte)from, 3)); + } + } + + [Fact] + public void ResetGenerationState_ClearsLastAcceptedChainId() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + heatMap.RecordChainAccepted(1); + heatMap.ResetGenerationState(); + heatMap.RecordChainAccepted(5); + + // No transition from 1→5 because generation state was reset. + Assert.Equal(0, heatMap.GetTransitionCount(1, 5)); + } + + [Fact] + public void GetTokenRecallRate_ReturnsZeroWhenNeverAttempted() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + Assert.Equal(0d, heatMap.GetTokenRecallRate(10)); + } + + [Fact] + public void GetTokenRecallRate_ReturnsCorrectRatio() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + heatMap.RecordChainAttempt(0, [5], speculativeStartIndex: 0); + heatMap.RecordChainAttempt(0, [5], speculativeStartIndex: 0); + heatMap.RecordChainAttempt(0, [5], speculativeStartIndex: 0); + heatMap.RecordChainAttempt(0, [5], speculativeStartIndex: 0); + heatMap.RecordTokenAccepted(0, 5); + heatMap.RecordTokenAccepted(0, 5); + heatMap.RecordTokenAccepted(0, 5); + + Assert.Equal(0.75d, heatMap.GetTokenRecallRate(5)); + } + + [Fact] + public void GetTransitionCount_ReturnsRecordedCount() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + heatMap.RecordChainAccepted(2); + heatMap.RecordChainAccepted(7); + heatMap.ResetGenerationState(); + heatMap.RecordChainAccepted(2); + heatMap.RecordChainAccepted(7); + + Assert.Equal(2, heatMap.GetTransitionCount(2, 7)); + } + + [Fact] + public void GetIncomingChains_ReturnsPredecessorsSortedByCount() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + // Chain 10 is reached from chain 1 (3 times) and chain 2 (1 time). 
+ for (var i = 0; i < 3; i++) + { + heatMap.RecordChainAccepted(1); + heatMap.RecordChainAccepted(10); + heatMap.ResetGenerationState(); + } + + heatMap.RecordChainAccepted(2); + heatMap.RecordChainAccepted(10); + + var incoming = heatMap.GetIncomingChains(10); + Assert.Equal(2, incoming.Count); + Assert.Equal(1, incoming[0].ChainId); + Assert.Equal(3, incoming[0].TransitionCount); + Assert.Equal(2, incoming[1].ChainId); + Assert.Equal(1, incoming[1].TransitionCount); + } + + [Fact] + public void GetOutgoingChains_ReturnsSuccessorsSortedByCount() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + // Chain 5 leads to chain 8 (2 times) and chain 9 (1 time). + heatMap.RecordChainAccepted(5); + heatMap.RecordChainAccepted(8); + heatMap.ResetGenerationState(); + heatMap.RecordChainAccepted(5); + heatMap.RecordChainAccepted(8); + heatMap.ResetGenerationState(); + heatMap.RecordChainAccepted(5); + heatMap.RecordChainAccepted(9); + + var outgoing = heatMap.GetOutgoingChains(5); + Assert.Equal(2, outgoing.Count); + Assert.Equal(8, outgoing[0].ChainId); + Assert.Equal(2, outgoing[0].TransitionCount); + Assert.Equal(9, outgoing[1].ChainId); + Assert.Equal(1, outgoing[1].TransitionCount); + } + + [Fact] + public void GetHotPaths_FindsFrequentChainSequences() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + // Create a clear hot-path: 1 → 2 → 3 (each transition fires 5 times). 
+ for (var i = 0; i < 5; i++) + { + heatMap.RecordChainAccepted(1); + heatMap.RecordChainAccepted(2); + heatMap.RecordChainAccepted(3); + heatMap.ResetGenerationState(); + } + + var hotPaths = heatMap.GetHotPaths(maxDepth: 5, maxResults: 10, minTransitions: 2); + Assert.True(hotPaths.Count > 0); + + var mainPath = hotPaths[0]; + Assert.True(mainPath.ChainSequence.Length >= 3); + Assert.Equal(1, mainPath.ChainSequence[0]); + Assert.Equal(2, mainPath.ChainSequence[1]); + Assert.Equal(3, mainPath.ChainSequence[2]); + Assert.Equal(5, mainPath.MinTransitionCount); + } + + [Fact] + public void GetHotPaths_RespectsMinTransitionsThreshold() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + // Only 1 transition — below the default minTransitions=2. + heatMap.RecordChainAccepted(1); + heatMap.RecordChainAccepted(2); + + var hotPaths = heatMap.GetHotPaths(minTransitions: 2); + Assert.Empty(hotPaths); + } + + [Fact] + public void GetTopTokensByAcceptCount_ReturnsDescendingOrder() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + heatMap.RecordTokenAccepted(0, 3); + heatMap.RecordTokenAccepted(0, 7); + heatMap.RecordTokenAccepted(0, 7); + heatMap.RecordTokenAccepted(0, 7); + heatMap.RecordTokenAccepted(0, 5); + heatMap.RecordTokenAccepted(0, 5); + + var top = heatMap.GetTopTokensByAcceptCount(maxResults: 3); + Assert.Equal(3, top.Count); + Assert.Equal(7, top[0].TokenId); + Assert.Equal(3, top[0].AcceptCount); + Assert.Equal(5, top[1].TokenId); + Assert.Equal(2, top[1].AcceptCount); + Assert.Equal(3, top[2].TokenId); + Assert.Equal(1, top[2].AcceptCount); + } + + [Fact] + public void GetTopTokensByRecallRate_RespectsMinAttempts() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + // Token 3: 1 attempt, 1 accept (100% rate but below minAttempts=3). + heatMap.RecordChainAttempt(0, [3], 0); + heatMap.RecordTokenAccepted(0, 3); + + // Token 7: 5 attempts, 4 accepts (80% rate, above minAttempts). 
+ for (var i = 0; i < 5; i++) + { + heatMap.RecordChainAttempt(0, [7], 0); + } + + for (var i = 0; i < 4; i++) + { + heatMap.RecordTokenAccepted(0, 7); + } + + var top = heatMap.GetTopTokensByRecallRate(maxResults: 10, minAttempts: 3); + Assert.Single(top); + Assert.Equal(7, top[0].TokenId); + } + + [Fact] + public void RankBucketsForCompaction_ReturnsWorstFirst() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + var table = new ChainBucketTable([ + new ChainBucket(0, [1, 2], 1.0f), + new ChainBucket(1, [3, 4], 1.0f) + ]); + + // Chain 0: 10 attempts, 8 accepts (80%). + for (var i = 0; i < 10; i++) + { + heatMap.RecordChainAttempt(0, [1, 2], 1); + } + + for (var i = 0; i < 8; i++) + { + heatMap.RecordChainAccepted(0); + } + + // Chain 1: 10 attempts, 2 accepts (20%). + for (var i = 0; i < 10; i++) + { + heatMap.RecordChainAttempt(1, [3, 4], 1); + } + + for (var i = 0; i < 2; i++) + { + heatMap.RecordChainAccepted(1); + } + + var rankings = heatMap.RankBucketsForCompaction(table); + Assert.Equal(2, rankings.Count); + Assert.Equal(1, rankings[0].ChainId); // Worst first (20%). + Assert.Equal(0, rankings[1].ChainId); // Better (80%). + } + + [Fact] + public void RankBucketsForCompaction_MarksHotPathChains() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + var table = new ChainBucketTable([ + new ChainBucket(0, [1, 2], 1.0f), + new ChainBucket(1, [3, 4], 1.0f), + new ChainBucket(2, [5, 6], 1.0f) + ]); + + // Create a hot-path: 0 → 1 (5 transitions). 
+ for (var i = 0; i < 5; i++) + { + heatMap.RecordChainAccepted(0); + heatMap.RecordChainAccepted(1); + heatMap.ResetGenerationState(); + } + + var rankings = heatMap.RankBucketsForCompaction(table); + var chain0 = rankings.Single(r => r.ChainId == 0); + var chain1 = rankings.Single(r => r.ChainId == 1); + var chain2 = rankings.Single(r => r.ChainId == 2); + + Assert.True(chain0.OnHotPath); + Assert.True(chain1.OnHotPath); + Assert.False(chain2.OnHotPath); + } + + [Fact] + public void IdentifyLowValueBuckets_ReturnsChainsBelowThreshold() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + var table = new ChainBucketTable([ + new ChainBucket(0, [1, 2], 1.0f), + new ChainBucket(1, [3, 4], 1.0f) + ]); + + // Chain 0: 10 attempts, 1 accept (10% — below 50% threshold). + for (var i = 0; i < 10; i++) + { + heatMap.RecordChainAttempt(0, [1, 2], 1); + } + + heatMap.RecordChainAccepted(0); + + // Chain 1: 10 attempts, 8 accepts (80% — above threshold). + for (var i = 0; i < 10; i++) + { + heatMap.RecordChainAttempt(1, [3, 4], 1); + } + + for (var i = 0; i < 8; i++) + { + heatMap.RecordChainAccepted(1); + } + + var lowValue = heatMap.IdentifyLowValueBuckets(table, threshold: 0.5); + Assert.Contains((byte)0, lowValue); + Assert.DoesNotContain((byte)1, lowValue); + } + + [Fact] + public void IdentifyLowValueBuckets_ExcludesHotPathChains() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + var table = new ChainBucketTable([ + new ChainBucket(0, [1, 2], 1.0f), + new ChainBucket(1, [3, 4], 1.0f) + ]); + + // Chain 0: low recall but on a hot-path. + for (var i = 0; i < 10; i++) + { + heatMap.RecordChainAttempt(0, [1, 2], 1); + } + + heatMap.RecordChainAccepted(0); + + // Create hot-path: 0 → 1 (3 transitions). 
+ for (var i = 0; i < 3; i++) + { + heatMap.RecordChainAccepted(0); + heatMap.RecordChainAccepted(1); + heatMap.ResetGenerationState(); + } + + var lowValue = heatMap.IdentifyLowValueBuckets(table, threshold: 0.5, minAttempts: 2); + Assert.DoesNotContain((byte)0, lowValue); // Protected by hot-path. + } + + [Fact] + public void Reset_ClearsAllCountersAndTransitions() + { + var heatMap = new BucketRecallHeatMap(TestVocabSize); + + heatMap.RecordChainAttempt(0, [5, 10], 0); + heatMap.RecordTokenAccepted(0, 5); + heatMap.RecordChainAccepted(0); + heatMap.RecordChainAccepted(1); + + heatMap.Reset(); + + Assert.Equal(0, heatMap.GetAttemptCount(5)); + Assert.Equal(0, heatMap.GetAcceptCount(5)); + Assert.Equal(0, heatMap.GetChainAttemptCount(0)); + Assert.Equal(0, heatMap.GetChainAcceptCount(0)); + Assert.Equal(0, heatMap.GetTransitionCount(0, 1)); + } + + [Fact] + public void MergeFrom_AddsCounts() + { + var a = new BucketRecallHeatMap(TestVocabSize); + var b = new BucketRecallHeatMap(TestVocabSize); + + a.RecordChainAttempt(0, [5], 0); + a.RecordTokenAccepted(0, 5); + b.RecordChainAttempt(0, [5], 0); + b.RecordTokenAccepted(0, 5); + + a.MergeFrom(b); + + Assert.Equal(2, a.GetAttemptCount(5)); + Assert.Equal(2, a.GetAcceptCount(5)); + } + + [Fact] + public void MergeFrom_AddsTransitionCounts() + { + var a = new BucketRecallHeatMap(TestVocabSize); + var b = new BucketRecallHeatMap(TestVocabSize); + + a.RecordChainAccepted(1); + a.RecordChainAccepted(2); + + b.RecordChainAccepted(1); + b.RecordChainAccepted(2); + b.ResetGenerationState(); + b.RecordChainAccepted(1); + b.RecordChainAccepted(2); + + a.MergeFrom(b); + + Assert.Equal(3, a.GetTransitionCount(1, 2)); + } + + [Fact] + public void MergeFrom_ThrowsOnVocabSizeMismatch() + { + var a = new BucketRecallHeatMap(32); + var b = new BucketRecallHeatMap(64); + + Assert.Throws(() => a.MergeFrom(b)); + } + + [Fact] + public void ExportCounters_RoundTripsViaFromCounters() + { + var original = new 
using BitNetSharp.Core.Bucketing;

namespace BitNetSharp.Tests;

/// <summary>
/// Verifies that <see cref="BucketRecallVisualizer"/> renders well-formed Mermaid
/// markdown (fenced <c>xychart-beta</c> and <c>flowchart</c> blocks) from
/// bucket-recall telemetry.
/// </summary>
public sealed class BucketRecallVisualizerTests
{
    // A tiny vocabulary keeps the rendered diagrams small and the tests fast.
    private const int TestVocabSize = 32;

    // Deterministic token-id -> label mapping handed to every renderer call.
    private static string ResolveToken(int id) => $"tok_{id}";

    // Fixture: one chain bucket plus a heat map that has recorded a single
    // attempt and a single acceptance for that bucket.
    private static (BucketRecallHeatMap Map, ChainBucketTable Table) SingleChainFixture()
    {
        var map = new BucketRecallHeatMap(TestVocabSize);
        var table = new ChainBucketTable([new ChainBucket(0, [1, 2], 1.0f)]);
        map.RecordChainAttempt(0, [1, 2], 1);
        map.RecordChainAccepted(0);
        return (map, table);
    }

    // Fixture: a two-bucket table used by the hot-path and compaction tests.
    private static ChainBucketTable TwoBucketTable() =>
        new([
            new ChainBucket(0, [1, 2], 1.0f),
            new ChainBucket(1, [3, 4], 1.0f)
        ]);

    // Fixture: a heat map where buckets 0 and 1 were accepted back-to-back in
    // each of <paramref name="generations"/> generations, producing 0 -> 1
    // transition counts equal to the generation count.
    private static BucketRecallHeatMap HotPathMap(int generations)
    {
        var map = new BucketRecallHeatMap(TestVocabSize);
        for (var round = 0; round < generations; round++)
        {
            map.RecordChainAccepted(0);
            map.RecordChainAccepted(1);
            map.ResetGenerationState();
        }

        return map;
    }

    // Fixture: a heat map with five attempts but only one acceptance on bucket
    // 0 — a 20% recall rate, i.e. a low-value bucket for compaction reports.
    private static BucketRecallHeatMap LowAcceptanceMap()
    {
        var map = new BucketRecallHeatMap(TestVocabSize);
        for (var attempt = 0; attempt < 5; attempt++)
        {
            map.RecordChainAttempt(0, [1, 2], 1);
        }

        map.RecordChainAccepted(0);
        return map;
    }

    [Fact]
    public void RenderTokenHeatMap_StartsWithMermaidXyChartBlock()
    {
        var map = new BucketRecallHeatMap(TestVocabSize);
        map.RecordTokenAccepted(0, 5);

        var markdown = BucketRecallVisualizer.RenderTokenHeatMap(map, ResolveToken);

        Assert.StartsWith("```mermaid", markdown, StringComparison.Ordinal);
        Assert.Contains("xychart-beta", markdown, StringComparison.Ordinal);
        // The fence must be closed (ignoring any trailing newline).
        Assert.EndsWith("```", markdown.TrimEnd(), StringComparison.Ordinal);
    }

    [Fact]
    public void RenderTokenHeatMap_ContainsExpectedTokenLabels()
    {
        var map = new BucketRecallHeatMap(TestVocabSize);
        map.RecordTokenAccepted(0, 5);
        map.RecordTokenAccepted(0, 5);
        map.RecordTokenAccepted(0, 10);

        var markdown = BucketRecallVisualizer.RenderTokenHeatMap(map, ResolveToken);

        // Every token with accept data should appear under its resolved label.
        Assert.Contains("tok_5", markdown, StringComparison.Ordinal);
        Assert.Contains("tok_10", markdown, StringComparison.Ordinal);
    }

    [Fact]
    public void RenderTokenHeatMap_ReturnsEmptyDiagramWhenNoData()
    {
        var map = new BucketRecallHeatMap(TestVocabSize);

        var markdown = BucketRecallVisualizer.RenderTokenHeatMap(map, ResolveToken);

        // Even with no recorded data the renderer emits a valid fenced block.
        Assert.StartsWith("```mermaid", markdown, StringComparison.Ordinal);
    }

    [Fact]
    public void RenderChainRecallChart_StartsWithMermaidXyChartBlock()
    {
        var (map, table) = SingleChainFixture();

        var markdown = BucketRecallVisualizer.RenderChainRecallChart(map, table, ResolveToken);

        Assert.StartsWith("```mermaid", markdown, StringComparison.Ordinal);
        Assert.Contains("xychart-beta", markdown, StringComparison.Ordinal);
    }

    [Fact]
    public void RenderChainRecallChart_ContainsChainTokenLabels()
    {
        var (map, table) = SingleChainFixture();

        var markdown = BucketRecallVisualizer.RenderChainRecallChart(map, table, ResolveToken);

        // The chart labels chains by their member tokens' resolved names.
        Assert.Contains("tok_1", markdown, StringComparison.Ordinal);
    }

    [Fact]
    public void RenderHotPathDiagram_StartsWithMermaidFlowchart()
    {
        var map = HotPathMap(generations: 5);

        var markdown = BucketRecallVisualizer.RenderHotPathDiagram(map, TwoBucketTable(), ResolveToken);

        Assert.StartsWith("```mermaid", markdown, StringComparison.Ordinal);
        Assert.Contains("flowchart LR", markdown, StringComparison.Ordinal);
    }

    [Fact]
    public void RenderHotPathDiagram_ContainsEdgesWithTransitionCounts()
    {
        var map = HotPathMap(generations: 3);

        var markdown = BucketRecallVisualizer.RenderHotPathDiagram(map, TwoBucketTable(), ResolveToken);

        // Three generations of 0 -> 1 should render as an edge labelled |3|.
        Assert.Contains("-->|3|", markdown, StringComparison.Ordinal);
    }

    [Fact]
    public void RenderCompactionReport_StartsWithMermaidFlowchart()
    {
        var map = LowAcceptanceMap();

        var markdown = BucketRecallVisualizer.RenderCompactionReport(map, TwoBucketTable(), ResolveToken);

        Assert.StartsWith("```mermaid", markdown, StringComparison.Ordinal);
        Assert.Contains("flowchart", markdown, StringComparison.Ordinal);
    }

    [Fact]
    public void RenderCompactionReport_ColorsLowValueNodesRed()
    {
        var map = LowAcceptanceMap();
        var table = new ChainBucketTable([new ChainBucket(0, [1, 2], 1.0f)]);

        // 1 accept / 5 attempts = 0.2 recall, below the 0.5 threshold, so the
        // bucket's node should be styled with the red fill.
        var markdown = BucketRecallVisualizer.RenderCompactionReport(map, table, ResolveToken, threshold: 0.5);

        Assert.Contains("fill:#f44", markdown, StringComparison.Ordinal);
    }
}
a/tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs b/tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs index a3e2df0..964415a 100644 --- a/tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs +++ b/tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs @@ -1,5 +1,6 @@ using BitNetSharp.App; using BitNetSharp.Core; +using BitNetSharp.Core.Bucketing; namespace BitNetSharp.Tests; @@ -211,6 +212,54 @@ public void BuiltInModelsPreserveTrainedResponsesAcrossCheckpointRoundTrips() Assert.True(traditionalRoundTrip.ResponsesMatch); } + [Fact] + public void ChainBucketRecallWithHeatMapEnabled_TracksTokenAcceptanceAndTransitions() + { + var examples = BitNetTrainingCorpus.CreateBenchmarkExamples(); + var model = BitNetPaperModel.CreateForTrainingCorpus( + examples, + VerbosityLevel.Quiet, + enableChainBuckets: true); + model.Train(examples, epochs: 3); + model.MineAndLoadBuckets(examples); + + Assert.NotNull(model.RecallHeatMap); + + foreach (var example in examples) + { + model.GenerateResponse(example.Prompt, maxTokens: 8); + } + + // After multiple generations with bucketing enabled, the heat map should have + // recorded at least some chain attempts. 
+ var totalChainAttempts = 0L; + for (var i = 0; i < ChainBucketTable.MaxBuckets; i++) + { + totalChainAttempts += model.RecallHeatMap.GetChainAttemptCount((byte)i); + } + + Assert.True(totalChainAttempts >= 0, "Heat map should be present and tracking (may be zero if no chains matched)."); + } + + [Fact] + public void ChainBucketRecallWithHeatMapDisabled_DoesNotAllocateHeatMap() + { + var examples = BitNetTrainingCorpus.CreateBenchmarkExamples(); + var model = new BitNetPaperModel( + new BitNetOptions( + BitNetTrainingCorpus.CreateBenchmarkVocabulary(), + VerbosityLevel.Quiet, + EnableChainBuckets: true, + EnableRecallHeatMap: false)); + model.MineAndLoadBuckets(examples); + + Assert.Null(model.RecallHeatMap); + + model.GenerateResponse("what is bitnet", maxTokens: 4); + + Assert.Null(model.RecallHeatMap); + } + private static async Task WithBenchmarkOptionsAsync(HostedAgentBenchmarkOptions options, Func assertion) { var originalValue = Environment.GetEnvironmentVariable(HostedAgentBenchmarkOptions.EnvironmentVariableName);