diff --git a/.vitepress/config/apiReferenceSidebar.ts b/.vitepress/config/apiReferenceSidebar.ts index 6b8c35ed..428d3aae 100644 --- a/.vitepress/config/apiReferenceSidebar.ts +++ b/.vitepress/config/apiReferenceSidebar.ts @@ -53,6 +53,7 @@ const chatWrappersOrder = [ "Llama3ChatWrapper", "Llama2ChatWrapper", "MistralChatWrapper", + "Gemma4ChatWrapper", "GemmaChatWrapper", "ChatMLChatWrapper", "FalconChatWrapper", diff --git a/src/chatWrappers/Gemma4ChatWrapper.ts b/src/chatWrappers/Gemma4ChatWrapper.ts new file mode 100644 index 00000000..16f93e92 --- /dev/null +++ b/src/chatWrappers/Gemma4ChatWrapper.ts @@ -0,0 +1,253 @@ +import {ChatWrapper, ChatWrapperJinjaMatchConfiguration} from "../ChatWrapper.js"; +import { + ChatModelFunctionCall, ChatModelFunctions, ChatModelResponse, ChatWrapperGenerateContextStateOptions, ChatWrapperGeneratedContextState, + ChatWrapperSettings +} from "../types.js"; +import {LlamaText, SpecialToken, SpecialTokensText} from "../utils/LlamaText.js"; +import {jsonDumps} from "./utils/jsonDumps.js"; + +// source: https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4 +export class Gemma4ChatWrapper extends ChatWrapper { + public readonly wrapperName: string = "Gemma 4"; + + public readonly reasoning: boolean; + public readonly keepOnlyLastThought: boolean; + + public override readonly settings: ChatWrapperSettings = { + supportsSystemMessages: true, + functions: { + call: { + optionalPrefixSpace: false, + prefix: LlamaText(new SpecialTokensText("<|tool_call>call:")), + paramsPrefix: "{", + suffix: LlamaText(new SpecialTokensText("}")), + emptyCallParamsPlaceholder: undefined + }, + result: { + prefix: LlamaText(new SpecialTokensText("response:"), "{{functionName}}", "{"), + suffix: LlamaText(new SpecialTokensText("}")) + } + }, + segments: { + reiterateStackAfterFunctionCalls: true, + thought: { + prefix: LlamaText(new SpecialTokensText("<|channel>thought\n")), + suffix: LlamaText(new SpecialTokensText("")) + } + } + }; + + public constructor(options: { + /** + * Whether to promote the model to perform reasoning. + * + * Defaults to `true`. + */ + reasoning?: boolean, + + /** + * Whether to keep only the chain of thought from the last model response. + * + * Setting this to `false` will keep all the chain of thoughts from the model responses in the context state. + * + * Defaults to `true`. + */ + keepOnlyLastThought?: boolean + } = {}) { + super(); + + const { + reasoning = true, + keepOnlyLastThought = true + } = options; + + this.reasoning = reasoning; + this.keepOnlyLastThought = keepOnlyLastThought; + } + + public override generateContextState({ + chatHistory, availableFunctions, documentFunctionParams + }: ChatWrapperGenerateContextStateOptions): ChatWrapperGeneratedContextState { + const hasFunctions = Object.keys(availableFunctions ?? {}).length > 0; + const modifiedChatHistory = chatHistory.slice(); + + let systemMessage: LlamaText = LlamaText(); + if (modifiedChatHistory[0]?.type === "system") { + systemMessage = LlamaText.fromJSON(modifiedChatHistory[0].text); + modifiedChatHistory.shift(); + } + + if (hasFunctions) + systemMessage = LlamaText([ + systemMessage, + this.generateAvailableFunctionsSystemText(availableFunctions ?? {}, {documentParams: documentFunctionParams}) + ]); + + if (this.reasoning) + systemMessage = LlamaText([ + new SpecialTokensText("<|think|>"), + systemMessage + ]); + + if (systemMessage.values.length > 0) + modifiedChatHistory.unshift({ + type: "system", + text: systemMessage.toJSON() + }); + + const contextContent: LlamaText[] = [ + LlamaText(new SpecialToken("BOS")) + ]; + + for (let i = 0; i < modifiedChatHistory.length; i++) { + const isLastItem = i === modifiedChatHistory.length - 1; + const item = modifiedChatHistory[i]; + + if (item == null) + continue; + + if (item.type === "system") + contextContent.push( + LlamaText([ + new SpecialTokensText("<|turn>system\n"), + LlamaText.fromJSON(item.text), + isLastItem + ? LlamaText([]) + : new SpecialTokensText("\n") + ]) + ); + else if (item.type === "user") + contextContent.push( + LlamaText([ + new SpecialTokensText("<|turn>user\n"), + item.text, + isLastItem + ? LlamaText([]) + : new SpecialTokensText("\n") + ]) + ); + else if (item.type === "model") + contextContent.push(this._getModelResponse(item.response, true, isLastItem, this.keepOnlyLastThought)); + else + void (item satisfies never); + } + + return { + contextText: LlamaText(contextContent), + stopGenerationTriggers: [ + LlamaText(new SpecialToken("EOS")), + LlamaText(new SpecialToken("EOT")), + LlamaText(new SpecialTokensText("")), + LlamaText(new SpecialTokensText("\n")), + LlamaText("<|return|>") + ] + }; + } + + public override generateAvailableFunctionsSystemText(availableFunctions: ChatModelFunctions, {documentParams = true}: { + documentParams?: boolean + }): LlamaText { + return LlamaText( + Object.entries(availableFunctions) + .map(([name, definition]) => { + return LlamaText([ + new SpecialTokensText("<|tool>"), + "declaration:", name, "{", + jsonDumps({ + description: definition.description || undefined, + parameters: documentParams + ? (definition.params || {}) + : undefined + }), + "}", new SpecialTokensText("") + ]); + }) + ); + } + + public override generateModelResponseText(modelResponse: ChatModelResponse["response"], useRawValues: boolean = true): LlamaText { + return this._getModelResponse(modelResponse, useRawValues, false, false); + } + + /** @internal */ + private _getModelResponse( + modelResponse: ChatModelResponse["response"], + useRawValues: boolean, + isLastItem: boolean, + keepOnlyLastThought: boolean + ) { + const res: LlamaText[] = [ + LlamaText(new SpecialTokensText("<|turn>model\n")) + ]; + const pendingFunctionCalls: ChatModelFunctionCall[] = []; + + const addPendingFunctions = () => { + if (pendingFunctionCalls.length === 0) + return; + + res.push(this.generateFunctionCallsAndResults(pendingFunctionCalls, useRawValues)); + + pendingFunctionCalls.length = 0; + }; + + for (let index = 0; index < modelResponse.length; index++) { + const isLastResponse = index === modelResponse.length - 1; + const response = modelResponse[index]; + + if (response == null) + continue; + else if (response === "" && (!isLastResponse || !isLastItem)) + continue; + + if (typeof response === "string") { + addPendingFunctions(); + res.push(LlamaText(response)); + } else if (response.type === "segment") { + addPendingFunctions(); + + if (response.ended && response.raw != null && useRawValues) + res.push(LlamaText.fromJSON(response.raw)); + else if (response.segmentType === "thought") { + if (keepOnlyLastThought && !isLastItem) + continue; + + res.push( + LlamaText([ + new SpecialTokensText("<|channel>thought"), + response.text, + (isLastItem && !response.ended) + ? LlamaText([]) + : new SpecialTokensText("") + ]) + ); + } else if (response.segmentType === "comment") + continue; // unsupported + else + void (response.segmentType satisfies never); + } else if (response.type === "functionCall") { + if (response.startsNewChunk) + addPendingFunctions(); + + pendingFunctionCalls.push(response); + } else + void (response satisfies never); + } + + addPendingFunctions(); + + return LlamaText(res); + } + + /** @internal */ + public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate(): ChatWrapperJinjaMatchConfiguration { + return [ + [{}, {}], + [{reasoning: false}, {}], + [ + {reasoning: true}, + {}, + {additionalRenderParameters: {"enable_thinking": true}} + ] + ]; + } +} diff --git a/src/chatWrappers/utils/resolveChatWrapper.ts b/src/chatWrappers/utils/resolveChatWrapper.ts index 8cadf9e8..b1faa330 100644 --- a/src/chatWrappers/utils/resolveChatWrapper.ts +++ b/src/chatWrappers/utils/resolveChatWrapper.ts @@ -7,6 +7,7 @@ import {FalconChatWrapper} from "../FalconChatWrapper.js"; import {FunctionaryChatWrapper} from "../FunctionaryChatWrapper.js"; import {AlpacaChatWrapper} from "../AlpacaChatWrapper.js"; import {GemmaChatWrapper} from "../GemmaChatWrapper.js"; +import {Gemma4ChatWrapper} from "../Gemma4ChatWrapper.js"; import {JinjaTemplateChatWrapper, JinjaTemplateChatWrapperOptions} from "../generic/JinjaTemplateChatWrapper.js"; import {TemplateChatWrapper} from "../generic/TemplateChatWrapper.js"; import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; @@ -27,7 +28,7 @@ import type {GgufFileInfo} from "../../gguf/types/GgufFileInfoTypes.js"; export const specializedChatWrapperTypeNames = Object.freeze([ "general", "deepSeek", "qwen", "llama3.2-lightweight", "llama3.1", "llama3", "llama2Chat", "mistral", "alpacaChat", "functionary", - "chatML", "falconChat", "gemma", "harmony", "seed" + "chatML", "falconChat", "gemma4", "gemma", "harmony", "seed" ] as const); export type SpecializedChatWrapperTypeName = (typeof specializedChatWrapperTypeNames)[number]; @@ -56,6 +57,7 @@ export const chatWrappers = Object.freeze({ "functionary": FunctionaryChatWrapper, "chatML": ChatMLChatWrapper, "falconChat": FalconChatWrapper, + "gemma4": Gemma4ChatWrapper, "gemma": GemmaChatWrapper, "harmony": HarmonyChatWrapper, "seed": SeedChatWrapper, @@ -70,7 +72,8 @@ const chatWrapperToConfigType = new Map( ); const specializedChatWrapperRelatedTexts = { - "harmony": ["gpt", "gpt-oss"] + "harmony": ["gpt", "gpt-oss"], + "gemma4": ["gemma 4", "gemma-4"] } satisfies Partial>; export type BuiltInChatWrapperType = InstanceType; @@ -364,6 +367,8 @@ export function resolveChatWrapper( return createSpecializedChatWrapper(Llama3ChatWrapper); else if (includesText(modelNames, ["Mistral", "Mistral Large", "Mistral Large Instruct", "Mistral-Large", "Codestral"])) return createSpecializedChatWrapper(MistralChatWrapper); + else if (includesText(modelNames, ["Gemma 4", "Gemma-4", "gemma-4"])) + return createSpecializedChatWrapper(Gemma4ChatWrapper); else if (includesText(modelNames, ["Gemma", "Gemma 2"])) return createSpecializedChatWrapper(GemmaChatWrapper); else if (includesText(modelNames, ["gpt-oss", "Gpt Oss", "Gpt-Oss", "openai_gpt-oss", "Openai_Gpt Oss", "openai.gpt-oss", "Openai.Gpt Oss"])) @@ -381,6 +386,8 @@ export function resolveChatWrapper( return createSpecializedChatWrapper(SeedChatWrapper); else if (modelJinjaTemplate.includes("<|start|>") && modelJinjaTemplate.includes("<|channel|>")) return createSpecializedChatWrapper(HarmonyChatWrapper); + else if (modelJinjaTemplate.includes("<|turn>") && modelJinjaTemplate.includes("<|tool_call>call:")) + return createSpecializedChatWrapper(Gemma4ChatWrapper); else if (modelJinjaTemplate.includes("<|im_start|>")) return createSpecializedChatWrapper(ChatMLChatWrapper); else if (modelJinjaTemplate.includes("[INST]")) @@ -430,9 +437,12 @@ export function resolveChatWrapper( return createSpecializedChatWrapper(FunctionaryChatWrapper); else if (lowercaseName === "dolphin" && splitLowercaseSubType.includes("mistral")) return createSpecializedChatWrapper(ChatMLChatWrapper); - else if (lowercaseName === "gemma") + else if (lowercaseName === "gemma") { + if (firstSplitLowercaseSubType === "4") + return createSpecializedChatWrapper(Gemma4ChatWrapper); + return createSpecializedChatWrapper(GemmaChatWrapper); - else if (splitLowercaseSubType.includes("chatml")) + } else if (splitLowercaseSubType.includes("chatml")) return createSpecializedChatWrapper(ChatMLChatWrapper); } } @@ -454,6 +464,8 @@ export function resolveChatWrapper( return createSpecializedChatWrapper(FalconChatWrapper); else if (arch === "gemma" || arch === "gemma2") return createSpecializedChatWrapper(GemmaChatWrapper); + else if (arch === "gemma4") + return createSpecializedChatWrapper(Gemma4ChatWrapper); } return null; diff --git a/src/gguf/insights/GgufInsights.ts b/src/gguf/insights/GgufInsights.ts index ed364c35..cb8b2211 100644 --- a/src/gguf/insights/GgufInsights.ts +++ b/src/gguf/insights/GgufInsights.ts @@ -248,29 +248,48 @@ export class GgufInsights { const tensorInfo = this._ggufFileInfo.fullTensorInfo ?? []; const slidingWindow = this.swaSize ?? 0; const kvUnified = false; - const usingSWA = !swaFullCache && slidingWindow > 0 && slidingWindow < contextSize && + const totalFileLayers = this._getTotalFileLayers(); + const hasSwaAttention = slidingWindow > 0; + const usingReducedSWA = hasSwaAttention && !swaFullCache && slidingWindow < contextSize && (this.trainContextSize == null || slidingWindow < this.trainContextSize); - const swaPattern = getSwaPatternForArchitecture( - this._ggufFileInfo.metadata?.general?.architecture, - this._ggufFileInfo.architectureMetadata?.attention?.sliding_window_pattern - ); - const nonSwaPercent = swaPattern <= 1 - ? 1 - : (1 / (swaPattern + (flashAttention ? -0.5 : -1))); + let graphRelevantTensorCount = 0; + let graphRelevantTensorElements = 0; + let totalTensorElements = 0; + + for (const singleTensorInfo of tensorInfo) { + let tensorElements = 0; + for (const dim of singleTensorInfo.dimensions) + tensorElements += Number(dim); + + totalTensorElements += tensorElements; + + if (!isGraphRelevantTensor(singleTensorInfo.name)) + continue; + + graphRelevantTensorCount++; + graphRelevantTensorElements += tensorElements; + } + + const effectiveGraphTensorCount = graphRelevantTensorCount > 0 + ? graphRelevantTensorCount + : tensorInfo.length; + const effectiveGraphTensorElements = graphRelevantTensorCount > 0 + ? graphRelevantTensorElements + : totalTensorElements; - // source: `llama_kv_cache_unified::get_padding` in `llama-kv-cache.cpp` - const kvCachePadding = 1; + const paddedContextSize = padSafeContextSize(contextSize, "up"); const actualContextSize = kvUnified ? padSafeContextSize(sequences * contextSize, "up") - : sequences * padSafeContextSize(contextSize, "up"); - const kvSize = usingSWA - ? ( - (1 - nonSwaPercent) * Math.min(actualContextSize, ggmlPad(sequences * slidingWindow + batchSize, kvCachePadding)) + - nonSwaPercent * actualContextSize - ) - : actualContextSize; - - const totalFileLayers = this._getTotalFileLayers(); + : sequences * paddedContextSize; + const fullAttentionKvSize = actualContextSize; + const swaBatchSize = hasSwaAttention && !swaFullCache + ? batchSize + 1 + : batchSize; + const swaKvSize = !hasSwaAttention + ? actualContextSize + : !usingReducedSWA + ? actualContextSize + : Math.min(actualContextSize, ggmlPad((sequences * slidingWindow) + swaBatchSize, 256)); const totalLayersIncludingOutput = totalFileLayers + 1; const finalModelGpuLayers = Math.max( 0, @@ -284,13 +303,17 @@ export class GgufInsights { gpuKVCacheSize, cpuKVCacheSize, gpuRecurrentStateSize, - cpuRecurrentStateSize + cpuRecurrentStateSize, + maxAttentionLayerKvSize, + maxAttentionLayerHeadCountKv } = this._estimateContextCacheMemorySplitInBytes({ - kvSize, + fullAttentionKvSize, + swaKvSize, sequences, totalFileLayers, finalModelGpuLayers, usingGpu, + flashAttention, kvCacheKeyType, kvCacheValueType }); @@ -324,37 +347,82 @@ export class GgufInsights { const estimateGraphOverheadMemory = (): number => { const s1MB = Math.pow(1024, 2); - const tensorInfo = this._ggufFileInfo.fullTensorInfo ?? []; const expertCount = llmData?.expert_count ?? 0; const headCount = llmData?.attention?.head_count ?? 0; const embeddingLength = llmData?.embedding_length ?? 0; + const activeGraphTokens = roundUpToMultiple( + Math.max(1, Math.min(paddedContextSize, batchSize)), + Math.max(1, sequences) + ); + const graphContextSize = resolveGraphContextSizeForOverheadEstimation({ + fullAttentionKvSize, + trainContextSize: this.trainContextSize, + flashAttention, + headCount, + batchSize, + paddedContextSize, + sequences + }); let defaultCalculationAdjustment = 0; + const totalElements = effectiveGraphTensorCount === 0 + ? this.totalLayers * ( + ( + (llmData.embedding_length ?? 0) + + (llmData.feed_forward_length ?? 0) + ) / 2 + ) + : effectiveGraphTensorElements; + const tensorBasedGraphOverhead = (tensorElementMultiplier: number) => ( + (totalElements * tensorElementMultiplier * (graphContextSize / 4096)) + defaultCalculationAdjustment + ); + const batchLocalTensorBasedGraphOverhead = (tensorElementMultiplier: number) => ( + (totalElements * tensorElementMultiplier * (activeGraphTokens / 4096)) + defaultCalculationAdjustment + ); if (batchSize == null) return 0; + const genericNonFlashAttentionWorkspaceEstimate = !flashAttention + ? estimateNonFlashAttentionWorkspace({ + trainContextSize: this.trainContextSize, + fullAttentionKvSize, + swaKvSize, + hasSwaAttention, + maxAttentionLayerKvSize, + maxAttentionLayerHeadCountKv, + activeGraphTokens, + headCount + }) + : 0; + if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.llama) { if (expertCount > 0) { const expertsUsedCount = this._ggufFileInfo.architectureMetadata.expert_used_count ?? 2; - return int32TBytes * batchSize * (((expertsUsedCount + 1) * embeddingLength) + (kvSize * headCount)); + return Math.max( + int32TBytes * batchSize * (((expertsUsedCount + 1) * embeddingLength) + (graphContextSize * headCount)), + genericNonFlashAttentionWorkspaceEstimate + ); } - return int32TBytes * batchSize * (embeddingLength + (kvSize * headCount)); + return Math.max( + int32TBytes * batchSize * (embeddingLength + (graphContextSize * headCount)), + genericNonFlashAttentionWorkspaceEstimate + ); } else if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.qwen2) { if (modelGpuLayers === this.totalLayers) { defaultCalculationAdjustment -= (s1MB * 340) * ( this.trainContextSize == null ? 1 - : kvSize / this.trainContextSize + : graphContextSize / this.trainContextSize ); } else { defaultCalculationAdjustment -= (s1MB * 250) + ( (s1MB * 50) * ( this.trainContextSize == null ? 1 - : kvSize / this.trainContextSize + : graphContextSize / this.trainContextSize ) ); } @@ -367,7 +435,7 @@ export class GgufInsights { (s1MB * 270) * ( this.trainContextSize == null ? 1 - : kvSize / this.trainContextSize + : graphContextSize / this.trainContextSize ) ); } else { @@ -375,14 +443,25 @@ export class GgufInsights { (s1MB * 150) * ( this.trainContextSize == null ? 1 - : Math.max(0, (1 - (kvSize / this.trainContextSize))) + : Math.max(0, (1 - (graphContextSize / this.trainContextSize))) ) ); } + } else if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.gemma3) { + const trainContextSize = Math.max(1, this.trainContextSize ?? graphContextSize); + const contextRatio = Math.min(1, Math.max(0, graphContextSize / trainContextSize)); + + return Math.max( + int32TBytes * batchSize * graphContextSize * headCount * (0.08 + Math.pow(contextRatio, 2)), + genericNonFlashAttentionWorkspaceEstimate + ); } else if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.stablelm) { const headCount = this._ggufFileInfo.architectureMetadata.attention?.head_count ?? 0; - return (int32TBytes * batchSize * kvSize * headCount) - (50 * s1MB); + return Math.max( + (int32TBytes * batchSize * graphContextSize * headCount) - (50 * s1MB), + genericNonFlashAttentionWorkspaceEstimate + ); // if (modelGpuLayers === this.totalLayers) { // defaultCalculationAdjustment += -(s1MB * 20) + ( @@ -402,34 +481,58 @@ export class GgufInsights { // ); // } } else if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.qwen3) { - return int32TBytes * batchSize * (embeddingLength + (kvSize * headCount)); + return Math.max( + int32TBytes * batchSize * (embeddingLength + (graphContextSize * headCount)), + genericNonFlashAttentionWorkspaceEstimate + ); + } else if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.gemma4) { + const trainContextSize = Math.max(1, this.trainContextSize ?? graphContextSize); + const contextRatio = Math.min(1, Math.max(0, graphContextSize / trainContextSize)); + const gemma4DenseShortContextScale = 0.4; + const gemma4DenseContextScaleExponent = 3; + const gemma4DenseEstimate = int32TBytes * batchSize * graphContextSize * headCount * + (gemma4DenseShortContextScale + Math.pow(contextRatio, gemma4DenseContextScaleExponent)); + + // Gemma 4 non-FA contexts allocate noticeably less temporary attention workspace than the + // generic heuristic predicts for the dense models. The MoE variants keep a larger + // tensor-driven graph reserve, so blend the dense fit with the tensor-size heuristic. + // `gemma4DenseShortContextScale` is a calibration floor from `inspect measure` runs for dense Gemma 4 models. + if (expertCount > 0) { + const tensorBasedEstimate = tensorBasedGraphOverhead(77.655); + const moeBlendWeight = Math.sqrt(contextRatio); + + return Math.max( + gemma4DenseEstimate, + (gemma4DenseEstimate + ((tensorBasedEstimate - gemma4DenseEstimate) * moeBlendWeight)) * 1.01, + genericNonFlashAttentionWorkspaceEstimate + ); + } + + return Math.max(gemma4DenseEstimate, genericNonFlashAttentionWorkspaceEstimate); } else if (expertCount > 0) { const expertsUsedCount = this._ggufFileInfo.architectureMetadata.expert_used_count ?? 2; - return int32TBytes * batchSize * (((expertsUsedCount + 1) * embeddingLength) + (kvSize * headCount)); + return Math.max( + int32TBytes * batchSize * (((expertsUsedCount + 1) * embeddingLength) + (graphContextSize * headCount)), + genericNonFlashAttentionWorkspaceEstimate + ); } - const totalElements = tensorInfo.length === 0 - ? this.totalLayers * ( - ( - (llmData.embedding_length ?? 0) + - (llmData.feed_forward_length ?? 0) - ) / 2 - ) - : tensorInfo.reduce((res, tensor) => { - return res + tensor.dimensions.reduce((res: number, dim) => res + Number(dim), 0); - }, 0); - if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.phi3) { // magic numbers for estimation. will be improved in the future - return (totalElements * 123 * (kvSize / 4096)) + defaultCalculationAdjustment; + return Math.max(tensorBasedGraphOverhead(123), genericNonFlashAttentionWorkspaceEstimate); } else if (this._ggufFileInfo.metadata.general?.architecture === GgufArchitectureType.cohere2) { // magic numbers for estimation. will be improved in the future - return (totalElements * 148 * (kvSize / 4096)) + defaultCalculationAdjustment; + return Math.max(tensorBasedGraphOverhead(148), genericNonFlashAttentionWorkspaceEstimate); } // magic numbers for estimation. will be improved in the future - return (totalElements * 77.655 * (kvSize / 4096)) + defaultCalculationAdjustment; + return Math.max( + !flashAttention + ? batchLocalTensorBasedGraphOverhead(77.655) + : tensorBasedGraphOverhead(77.655), + genericNonFlashAttentionWorkspaceEstimate + ); }; // source: `llama_context::graph_max_nodes` in `llama-context.cpp` @@ -449,10 +552,10 @@ export class GgufInsights { this._ggufFileInfo.metadata?.general?.architecture, Math.min(actualContextSize, batchSize) ); - const maxNodes = Math.max(maxNodesMultiplier.min, maxNodesMultiplier.multiplier * tensorInfo.length); + const maxNodes = Math.max(maxNodesMultiplier.min, maxNodesMultiplier.multiplier * effectiveGraphTensorCount); const cpuNodes = totalFileLayers === 0 ? 0 - : maxNodesMultiplier.multiplier * (tensorInfo.length * (finalCpuLayers / totalFileLayers)); + : maxNodesMultiplier.multiplier * (effectiveGraphTensorCount * (finalCpuLayers / totalFileLayers)); const gpuNodes = maxNodes - cpuNodes; const gpuComputeBufferSize = (this._llama._consts.ggmlTensorOverhead * gpuNodes) + @@ -460,7 +563,7 @@ export class GgufInsights { const cpuComputeBufferSize = (this._llama._consts.ggmlTensorOverhead * cpuNodes) + this._llama._bindings.getGgmlGraphOverheadCustom(cpuNodes, false); - const graphOverheadMemory = (flashAttention || !includeGraphOverhead) + const graphOverheadMemory = !includeGraphOverhead ? 0 : estimateGraphOverheadMemory(); const graphOverheadGpuSize = (usingGpu && totalFileLayers > 0) @@ -580,19 +683,23 @@ export class GgufInsights { } private _estimateContextCacheMemorySplitInBytes({ - kvSize, + fullAttentionKvSize, + swaKvSize, sequences, totalFileLayers, finalModelGpuLayers, usingGpu, + flashAttention, kvCacheKeyType = GgmlType.F16, kvCacheValueType = GgmlType.F16 }: { - kvSize: number, + fullAttentionKvSize: number, + swaKvSize: number, sequences: number, totalFileLayers: number, finalModelGpuLayers: number, usingGpu: boolean, + flashAttention: boolean, kvCacheKeyType?: GgmlType, kvCacheValueType?: GgmlType }) { @@ -603,8 +710,24 @@ export class GgufInsights { const nEmbdHeadK = this._ggufFileInfo.architectureMetadata.attention?.key_length ?? ((nHead == 0) ? 0 : (nEmbd / nHead)); const nHeadKv: number | number[] = this._ggufFileInfo.architectureMetadata.attention?.head_count_kv ?? nHead; const nEmbdHeadV = this._ggufFileInfo.architectureMetadata.attention?.value_length ?? ((nHead == 0) ? 0 : nEmbd / nHead); + const nEmbdHeadKSwa = this._ggufFileInfo.architectureMetadata.attention?.key_length_swa; + const nEmbdHeadVSwa = this._ggufFileInfo.architectureMetadata.attention?.value_length_swa; + const sharedKvLayers = this._ggufFileInfo.architectureMetadata.attention?.shared_kv_layers; + const slidingWindowPattern = this._ggufFileInfo.architectureMetadata.attention?.sliding_window_pattern; const keyTypeSize = this._llama._bindings.getTypeSizeForGgmlType(kvCacheKeyType) ?? this._llama._consts.ggmlTypeF16Size; const valueTypeSize = this._llama._bindings.getTypeSizeForGgmlType(kvCacheValueType) ?? this._llama._consts.ggmlTypeF16Size; + const nHeadKvValues = nHeadKv as unknown; + let maxLayerValueEmbedding = 0; + + if (!flashAttention && nHeadKvValues instanceof Array) { + for (let i = 0; i < totalFileLayers; i++) { + const layerHeadCountKv = resolveLayerHeadCountKv(nHeadKvValues, i, nHead); + const isSwaLayer = isSwaLayerAtIndex(architecture, slidingWindowPattern, i); + const layerValueEmbedding = resolveLayerHeadDimension(nEmbdHeadV, nEmbdHeadVSwa, isSwaLayer) * layerHeadCountKv; + + maxLayerValueEmbedding = Math.max(maxLayerValueEmbedding, layerValueEmbedding); + } + } // source: `llama_model::load_tensors` in `llama-model.cpp` // repeating layers are assigned to GPU from `i_gpu_start = n_layer + 1 - n_gpu_layers` @@ -618,6 +741,8 @@ export class GgufInsights { let cpuKvElementsV = 0; let gpuRecurrentLayers = 0; let cpuRecurrentLayers = 0; + let maxAttentionLayerKvSize = 0; + let maxAttentionLayerHeadCountKv = 0; for (let i = 0; i < totalFileLayers; i++) { const isGpuLayer = i >= gpuRepeatingLayerStart; @@ -629,9 +754,22 @@ export class GgufInsights { else cpuRecurrentLayers++; } else { + if (!doesLayerOwnKvCache(totalFileLayers, i, sharedKvLayers)) + continue; + const nHeadKvLayer = resolveLayerHeadCountKv(nHeadKv, i, nHead); - const layerElementsK = nEmbdHeadK * nHeadKvLayer * kvSize; - const layerElementsV = nEmbdHeadV * nHeadKvLayer * kvSize; + const isSwaLayer = isSwaLayerAtIndex(architecture, slidingWindowPattern, i); + const layerKvSize = isSwaLayer + ? swaKvSize + : fullAttentionKvSize; + maxAttentionLayerKvSize = Math.max(maxAttentionLayerKvSize, layerKvSize); + maxAttentionLayerHeadCountKv = Math.max(maxAttentionLayerHeadCountKv, nHeadKvLayer); + const layerElementsK = resolveLayerHeadDimension(nEmbdHeadK, nEmbdHeadKSwa, isSwaLayer) * nHeadKvLayer * layerKvSize; + const layerElementsV = layerKvSize * ( + maxLayerValueEmbedding > 0 + ? maxLayerValueEmbedding + : (resolveLayerHeadDimension(nEmbdHeadV, nEmbdHeadVSwa, isSwaLayer) * nHeadKvLayer) + ); if (isGpuLayer) { gpuKvElementsK += layerElementsK; @@ -658,7 +796,9 @@ export class GgufInsights { gpuKVCacheSize, cpuKVCacheSize, gpuRecurrentStateSize, - cpuRecurrentStateSize + cpuRecurrentStateSize, + maxAttentionLayerKvSize, + maxAttentionLayerHeadCountKv }; } @@ -941,49 +1081,212 @@ function isTokenEmbedLayer(layerName: string) { return firstPart === "token_embd"; } +function isGraphRelevantTensor(tensorName: string): boolean { + return isInputLayer(tensorName) || + isOutputLayer(tensorName) || + tensorName.startsWith("blk.") || + tensorName.startsWith("enc.blk.") || + tensorName.startsWith("dec.blk."); +} + function ggmlPad(value: number, padding: number): number { return ((value + padding - 1) & ~(padding - 1)); } -function getSwaPatternForArchitecture(architecture?: GgufArchitectureType, slidingWindowPattern?: number | number[]): number { - if (typeof slidingWindowPattern === "number") - return slidingWindowPattern; +function roundUpToMultiple(value: number, multiple: number): number { + if (multiple <= 1) + return value; + + return Math.ceil(value / multiple) * multiple; +} + +function resolveGraphContextSizeForOverheadEstimation({ + fullAttentionKvSize, + trainContextSize, + flashAttention, + headCount, + batchSize, + paddedContextSize, + sequences +}: { + fullAttentionKvSize: number, + trainContextSize: number | undefined, + flashAttention: boolean, + headCount: number, + batchSize: number, + paddedContextSize: number, + sequences: number +}) { + // heuristic coefficients fit to estimate llama.cpp graph-reserve behavior + const flashAttentionMinContextMultiplier = 0.5; + const flashAttentionMaxContextMultiplier = 0.78; + const flashAttentionMinHeadCountForScaling = 4; + const flashAttentionContextRatioLog2Cap = 2; + const flashAttentionContextRatioLog2Scale = 0.05; + const longContextOverflowStartRatio = 1.25; + const longContextOverflowGrowthScale = 0.1; + const longContextMaxMultiplierIncrease = 0.4; + + const normalizedTrainContextSize = trainContextSize == null || trainContextSize <= 0 + ? Math.max(1, fullAttentionKvSize) + : trainContextSize; + const contextRatio = Math.max(1, fullAttentionKvSize / normalizedTrainContextSize); + + if (flashAttention) { + const activeGraphTokens = roundUpToMultiple( + Math.max(1, Math.min(paddedContextSize, batchSize)), + Math.max(1, sequences) + ); + const flashContextMultiplierBase = + flashAttentionMinContextMultiplier + (1 / Math.max(flashAttentionMinHeadCountForScaling, headCount)); + const flashContextMultiplierLongContextAdjustment = + Math.min(flashAttentionContextRatioLog2Cap, Math.log2(contextRatio)) * flashAttentionContextRatioLog2Scale; + const flashContextMultiplier = Math.max( + flashAttentionMinContextMultiplier, + Math.min( + flashAttentionMaxContextMultiplier, + flashContextMultiplierBase + flashContextMultiplierLongContextAdjustment + ) + ); + + return activeGraphTokens * flashContextMultiplier; + } + + const contextOverflow = Math.max(0, contextRatio - longContextOverflowStartRatio); + const longContextMultiplier = 1 + Math.min( + longContextMaxMultiplierIncrease, + longContextOverflowGrowthScale * contextOverflow * contextOverflow + ); + + return fullAttentionKvSize * longContextMultiplier; +} + +function estimateNonFlashAttentionWorkspace({ + trainContextSize, + fullAttentionKvSize, + swaKvSize, + hasSwaAttention, + maxAttentionLayerKvSize, + maxAttentionLayerHeadCountKv, + activeGraphTokens, + headCount +}: { + trainContextSize: number | undefined, + fullAttentionKvSize: number, + swaKvSize: number, + hasSwaAttention: boolean, + maxAttentionLayerKvSize: number, + maxAttentionLayerHeadCountKv: number, + activeGraphTokens: number, + headCount: number +}) { + const floatBytes = 4; // sizeof(float) + const strongGqaMaxKvToQHeadRatio = 0.5; + const minAttentionScoreWorkspaceScale = 0.4; + const additionalAttentionScoreWorkspaceScale = 0.6; + + if (maxAttentionLayerKvSize <= 0 || activeGraphTokens <= 0 || headCount <= 0) + return 0; + + const attentionScoresWorkspace = floatBytes * activeGraphTokens * maxAttentionLayerKvSize * headCount; + const attentionMaskWorkspace = floatBytes * activeGraphTokens * ( + hasSwaAttention + ? fullAttentionKvSize + swaKvSize + : maxAttentionLayerKvSize + ); + if (!hasSwaAttention) + // source: non-FA reserve path in `llm_graph_context::build_attn_mha` + `build_attn_inp_kq_mask` in `llama-graph.cpp` + // reserves the full KQ tensor and the matching F32 attention mask for the ubatch-local graph + return attentionScoresWorkspace + attentionMaskWorkspace; + + // the explicit KQ workspace floor matches the non-FA reserve path well for MHA-like layouts, + // but it becomes too aggressive for strong GQA / MQA hybrid models where KV heads are much fewer than Q heads + if (maxAttentionLayerHeadCountKv / headCount < strongGqaMaxKvToQHeadRatio) + return attentionMaskWorkspace; + + const normalizedTrainContextSize = Math.max(1, trainContextSize ?? maxAttentionLayerKvSize); + const contextRatio = Math.min(1, Math.max(0, maxAttentionLayerKvSize / normalizedTrainContextSize)); + const attentionScoreWorkspaceScale = + minAttentionScoreWorkspaceScale + (additionalAttentionScoreWorkspaceScale * contextRatio); + + return (attentionScoresWorkspace * attentionScoreWorkspaceScale) + attentionMaskWorkspace; +} + +function isSwaLayerAtIndex( + architecture: GgufArchitectureType | undefined, + slidingWindowPattern: number | number[] | undefined, + layerIndex: number +): boolean { + if (layerIndex < 0) + return false; + + if (slidingWindowPattern instanceof Array) + return Boolean(slidingWindowPattern[layerIndex]); + + const [defaultPattern, denseFirst] = getSwaPatternForArchitecture(architecture); + const pattern = typeof slidingWindowPattern === "number" + ? Math.max(0, Math.floor(slidingWindowPattern)) + : defaultPattern; + + if (pattern === 0) + return true; + + return denseFirst + ? (layerIndex % pattern !== 0) + : (layerIndex % pattern < (pattern - 1)); +} + +function getSwaPatternForArchitecture(architecture?: GgufArchitectureType): [pattern: number, denseFirst: boolean] { // source: `llama_model::load_hparams` in `llama-model.cpp` - calls to `hparams.set_swa_pattern` switch (architecture) { case GgufArchitectureType.llama4: - return 4; + return [4, false]; case GgufArchitectureType.afmoe: - return 4; + return [4, false]; case GgufArchitectureType.modernBert: - return 3; + return [3, true]; case GgufArchitectureType.phi3: - return 1; + return [1, false]; case GgufArchitectureType.plamo3: - return 8; + return [8, false]; case GgufArchitectureType.gemma2: - return 2; + return [2, false]; case GgufArchitectureType.gemma3: - return 6; + return [6, false]; case GgufArchitectureType.gemma3n: - return 5; + return [5, false]; case GgufArchitectureType.gemmaEmbedding: - return 6; + return [6, false]; case GgufArchitectureType.cohere2: - return 4; + return [4, false]; case GgufArchitectureType.olmo2: - return 4; + return [4, false]; case GgufArchitectureType.exaone4: - return 4; + return [4, false]; case GgufArchitectureType.exaoneMoe: - return 4; + return [4, false]; case GgufArchitectureType.gptOss: - return 2; + return [2, false]; case GgufArchitectureType.smallthinker: - return 4; + return [4, true]; } - return 1; + return [1, false]; +} + +function resolveLayerHeadDimension(defaultValue: number, swaValue: number | undefined, isSwaLayer: boolean): number { + if (isSwaLayer && swaValue != null) + return swaValue; + + return defaultValue; +} + +function doesLayerOwnKvCache(totalLayers: number, layerIndex: number, sharedKvLayers: number | undefined): boolean { + if (sharedKvLayers == null || sharedKvLayers <= 0) + return true; + + return layerIndex < Math.max(0, totalLayers - sharedKvLayers); } function resolveLayerHeadCountKv(nHeadKv: number | number[], layerIndex: number, nHead: number): number { @@ -1007,8 +1310,9 @@ function getRecurrentLayersPattern( architectureMetadata: GgufFileInfo["architectureMetadata"] ): RecurrentLayersPattern { const nHeadKv = architectureMetadata?.attention?.head_count_kv; + const nHeadKvValues: number | number[] | undefined = nHeadKv; const feedForwardLength = architectureMetadata?.feed_forward_length as number | number[] | undefined; - const hasRecurrentHeadCountKvEntry = Array.isArray(nHeadKv) && nHeadKv.some((value) => value === 0); + const hasRecurrentHeadCountKvEntry = nHeadKvValues instanceof Array && nHeadKvValues.some((value: number) => value === 0); if (architecture === GgufArchitectureType.falconH1) // source: `llama_model::load_hparams` in `llama-model.cpp`: @@ -1019,10 +1323,10 @@ function getRecurrentLayersPattern( // source: `llama_model::load_hparams` in `llama-model.cpp`: // `case LLM_ARCH_NEMOTRON_H / LLM_ARCH_NEMOTRON_H_MOE`: // `recurrent_layer_arr[i] = (n_head_kv(i) == 0 && n_ff(i) == 0)` - if (Array.isArray(nHeadKv)) + if (nHeadKvValues instanceof Array) return { type: "headCountKvAndFeedForward", - headCountKvValues: nHeadKv, + headCountKvValues: nHeadKvValues, feedForwardLength }; @@ -1055,10 +1359,10 @@ function getRecurrentLayersPattern( interval: Math.max(1, Math.floor(architectureMetadata?.full_attention_interval)) }; - if (hasRecurrentHeadCountKvEntry) + if (nHeadKvValues instanceof Array && hasRecurrentHeadCountKvEntry) return { type: "headCountKvArray", - values: nHeadKv + values: nHeadKvValues }; return "none"; @@ -1081,7 +1385,7 @@ function isLayerRecurrent(pattern: RecurrentLayersPattern, layerIndex: number): function resolveLayerFeedForwardLength(feedForwardLength: number | number[] | undefined, layerIndex: number): number { if (typeof feedForwardLength === "number") return feedForwardLength; - else if (Array.isArray(feedForwardLength)) + else if (feedForwardLength instanceof Array) return feedForwardLength[layerIndex] ?? 0; return 0; diff --git a/src/gguf/types/GgufMetadataTypes.ts b/src/gguf/types/GgufMetadataTypes.ts index 249cad32..e290dd7c 100644 --- a/src/gguf/types/GgufMetadataTypes.ts +++ b/src/gguf/types/GgufMetadataTypes.ts @@ -47,6 +47,7 @@ export const enum GgufArchitectureType { gemma2 = "gemma2", gemma3 = "gemma3", gemma3n = "gemma3n", + gemma4 = "gemma4", gemmaEmbedding = "gemma-embedding", starcoder2 = "starcoder2", mamba = "mamba", diff --git a/src/index.ts b/src/index.ts index 66d254fb..09535736 100644 --- a/src/index.ts +++ b/src/index.ts @@ -62,6 +62,7 @@ import {FalconChatWrapper} from "./chatWrappers/FalconChatWrapper.js"; import {AlpacaChatWrapper} from "./chatWrappers/AlpacaChatWrapper.js"; import {FunctionaryChatWrapper} from "./chatWrappers/FunctionaryChatWrapper.js"; import {GemmaChatWrapper} from "./chatWrappers/GemmaChatWrapper.js"; +import {Gemma4ChatWrapper} from "./chatWrappers/Gemma4ChatWrapper.js"; import {HarmonyChatWrapper} from "./chatWrappers/HarmonyChatWrapper.js"; import {TemplateChatWrapper, type TemplateChatWrapperOptions} from "./chatWrappers/generic/TemplateChatWrapper.js"; import { @@ -231,6 +232,7 @@ export { AlpacaChatWrapper, FunctionaryChatWrapper, GemmaChatWrapper, + Gemma4ChatWrapper, HarmonyChatWrapper, TemplateChatWrapper, type TemplateChatWrapperOptions, diff --git a/src/utils/LlamaText.ts b/src/utils/LlamaText.ts index 6675b762..70ba0ec2 100644 --- a/src/utils/LlamaText.ts +++ b/src/utils/LlamaText.ts @@ -122,6 +122,10 @@ class LlamaText { return LlamaTextConstructor.compare(this, other); } + public trim(): LlamaText { + return this.trimStart().trimEnd(); + } + public trimStart(): LlamaText { const newValues = this.values.slice();