From c855cd7cdd2150fa08d6046abb3d03d498e5a82e Mon Sep 17 00:00:00 2001 From: Richard Palethorpe Date: Mon, 6 Apr 2026 09:47:19 +0100 Subject: [PATCH 1/2] gal meta-data --- core/config/backend_capabilities.go | 458 ++++++++++++++++++ core/config/backend_capabilities_test.go | 116 +++++ core/config/model_config.go | 29 ++ core/gallery/gallery.go | 93 ++++ core/gallery/importers/diffuser.go | 2 +- core/gallery/importers/llama-cpp.go | 2 +- core/gallery/importers/local.go | 8 +- core/gallery/importers/mlx.go | 2 +- core/gallery/importers/transformers.go | 2 +- core/gallery/importers/vllm.go | 2 +- core/gallery/models_types.go | 20 + core/http/endpoints/localai/config_meta.go | 8 +- core/http/react-ui/e2e/models-gallery.spec.js | 124 ++++- core/http/react-ui/src/pages/Backends.jsx | 4 +- core/http/react-ui/src/pages/Models.jsx | 130 +++-- core/http/react-ui/src/utils/api.js | 2 + core/http/react-ui/src/utils/config.js | 1 + core/http/routes/ui_api.go | 231 +++++---- pkg/vram/cache.go | 41 ++ 19 files changed, 1129 insertions(+), 146 deletions(-) create mode 100644 core/config/backend_capabilities.go create mode 100644 core/config/backend_capabilities_test.go diff --git a/core/config/backend_capabilities.go b/core/config/backend_capabilities.go new file mode 100644 index 000000000000..02af395545be --- /dev/null +++ b/core/config/backend_capabilities.go @@ -0,0 +1,458 @@ +package config + +import ( + "slices" + "strings" +) + +// Usecase name constants — the canonical string values used in gallery entries, +// model configs (known_usecases), and UsecaseInfoMap keys. 
+const ( + UsecaseChat = "chat" + UsecaseCompletion = "completion" + UsecaseEdit = "edit" + UsecaseVision = "vision" + UsecaseEmbeddings = "embeddings" + UsecaseTokenize = "tokenize" + UsecaseImage = "image" + UsecaseVideo = "video" + UsecaseTranscript = "transcript" + UsecaseTTS = "tts" + UsecaseSoundGeneration = "sound_generation" + UsecaseRerank = "rerank" + UsecaseDetection = "detection" + UsecaseVAD = "vad" +) + +// GRPCMethod identifies a Backend service RPC from backend.proto. +type GRPCMethod string + +const ( + MethodPredict GRPCMethod = "Predict" + MethodPredictStream GRPCMethod = "PredictStream" + MethodEmbedding GRPCMethod = "Embedding" + MethodGenerateImage GRPCMethod = "GenerateImage" + MethodGenerateVideo GRPCMethod = "GenerateVideo" + MethodAudioTranscription GRPCMethod = "AudioTranscription" + MethodTTS GRPCMethod = "TTS" + MethodTTSStream GRPCMethod = "TTSStream" + MethodSoundGeneration GRPCMethod = "SoundGeneration" + MethodTokenizeString GRPCMethod = "TokenizeString" + MethodDetect GRPCMethod = "Detect" + MethodRerank GRPCMethod = "Rerank" + MethodVAD GRPCMethod = "VAD" +) + +// UsecaseInfo describes a single known_usecase value and how it maps +// to the gRPC backend API. +type UsecaseInfo struct { + // Flag is the ModelConfigUsecase bitmask value. + Flag ModelConfigUsecase + // GRPCMethod is the primary Backend service RPC this usecase maps to. + GRPCMethod GRPCMethod + // IsModifier is true when this usecase doesn't map to its own gRPC RPC + // but modifies how another RPC behaves (e.g., vision uses Predict with images). + IsModifier bool + // DependsOn names the usecase(s) this modifier requires (e.g., "chat"). + DependsOn string + // Description is a human/LLM-readable explanation of what this usecase means. + Description string +} + +// UsecaseInfoMap maps each known_usecase string to its gRPC and semantic info. 
+var UsecaseInfoMap = map[string]UsecaseInfo{ + UsecaseChat: { + Flag: FLAG_CHAT, + GRPCMethod: MethodPredict, + Description: "Conversational/instruction-following via the Predict RPC with chat templates.", + }, + UsecaseCompletion: { + Flag: FLAG_COMPLETION, + GRPCMethod: MethodPredict, + Description: "Text completion via the Predict RPC with a completion template.", + }, + UsecaseEdit: { + Flag: FLAG_EDIT, + GRPCMethod: MethodPredict, + Description: "Text editing via the Predict RPC with an edit template.", + }, + UsecaseVision: { + Flag: FLAG_VISION, + GRPCMethod: MethodPredict, + IsModifier: true, + DependsOn: UsecaseChat, + Description: "The model accepts images alongside text in the Predict RPC. For llama-cpp this requires an mmproj file.", + }, + UsecaseEmbeddings: { + Flag: FLAG_EMBEDDINGS, + GRPCMethod: MethodEmbedding, + Description: "Vector embedding generation via the Embedding RPC.", + }, + UsecaseTokenize: { + Flag: FLAG_TOKENIZE, + GRPCMethod: MethodTokenizeString, + Description: "Tokenization via the TokenizeString RPC without running inference.", + }, + UsecaseImage: { + Flag: FLAG_IMAGE, + GRPCMethod: MethodGenerateImage, + Description: "Image generation via the GenerateImage RPC (Stable Diffusion, Flux, etc.).", + }, + UsecaseVideo: { + Flag: FLAG_VIDEO, + GRPCMethod: MethodGenerateVideo, + Description: "Video generation via the GenerateVideo RPC.", + }, + UsecaseTranscript: { + Flag: FLAG_TRANSCRIPT, + GRPCMethod: MethodAudioTranscription, + Description: "Speech-to-text via the AudioTranscription RPC.", + }, + UsecaseTTS: { + Flag: FLAG_TTS, + GRPCMethod: MethodTTS, + Description: "Text-to-speech via the TTS RPC.", + }, + UsecaseSoundGeneration: { + Flag: FLAG_SOUND_GENERATION, + GRPCMethod: MethodSoundGeneration, + Description: "Music/sound generation via the SoundGeneration RPC (not speech).", + }, + UsecaseRerank: { + Flag: FLAG_RERANK, + GRPCMethod: MethodRerank, + Description: "Document reranking via the Rerank RPC.", + }, + 
UsecaseDetection: { + Flag: FLAG_DETECTION, + GRPCMethod: MethodDetect, + Description: "Object detection via the Detect RPC with bounding boxes.", + }, + UsecaseVAD: { + Flag: FLAG_VAD, + GRPCMethod: MethodVAD, + Description: "Voice activity detection via the VAD RPC.", + }, +} + +// BackendCapability describes which gRPC methods and usecases a backend supports. +// Derived from reviewing actual implementations in backend/go/ and backend/python/. +type BackendCapability struct { + // GRPCMethods lists the Backend service RPCs this backend implements. + GRPCMethods []GRPCMethod + // PossibleUsecases lists all usecase strings this backend can support. + PossibleUsecases []string + // DefaultUsecases lists the conservative safe defaults. + DefaultUsecases []string + // AcceptsImages indicates multimodal image input in Predict. + AcceptsImages bool + // AcceptsVideos indicates multimodal video input in Predict. + AcceptsVideos bool + // AcceptsAudios indicates multimodal audio input in Predict. + AcceptsAudios bool + // Description is a human-readable summary of the backend. + Description string +} + +// BackendCapabilities maps each backend name (as used in model configs and gallery +// entries) to its verified capabilities. This is the single source of truth for +// what each backend supports. +// +// Backend names use hyphens (e.g., "llama-cpp") matching the gallery convention. +// Use NormalizeBackendName() for names with dots (e.g., "llama.cpp"). 
+var BackendCapabilities = map[string]BackendCapability{ + // --- LLM / text generation backends --- + "llama-cpp": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding, MethodTokenizeString}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEdit, UsecaseEmbeddings, UsecaseTokenize, UsecaseVision}, + DefaultUsecases: []string{UsecaseChat}, + AcceptsImages: true, // requires mmproj + Description: "llama.cpp GGUF models — LLM inference with optional vision via mmproj", + }, + "vllm": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseVision}, + DefaultUsecases: []string{UsecaseChat}, + AcceptsImages: true, + AcceptsVideos: true, + Description: "vLLM engine — high-throughput LLM serving with optional multimodal", + }, + "vllm-omni": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodGenerateImage, MethodGenerateVideo, MethodTTS}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseImage, UsecaseVideo, UsecaseTTS, UsecaseVision}, + DefaultUsecases: []string{UsecaseChat}, + AcceptsImages: true, + AcceptsVideos: true, + AcceptsAudios: true, + Description: "vLLM omni-modal — supports text, image, video generation and TTS", + }, + "transformers": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding, MethodTTS, MethodSoundGeneration}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseTTS, UsecaseSoundGeneration}, + DefaultUsecases: []string{UsecaseChat}, + Description: "HuggingFace transformers — general-purpose Python inference", + }, + "mlx": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings}, + DefaultUsecases: []string{UsecaseChat}, + Description: "Apple MLX framework — optimized for Apple 
Silicon", + }, + "mlx-distributed": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings}, + DefaultUsecases: []string{UsecaseChat}, + Description: "MLX distributed inference across multiple Apple Silicon devices", + }, + "mlx-vlm": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseVision}, + DefaultUsecases: []string{UsecaseChat, UsecaseVision}, + AcceptsImages: true, + AcceptsAudios: true, + Description: "MLX vision-language models with multimodal input", + }, + "mlx-audio": { + GRPCMethods: []GRPCMethod{MethodPredict, MethodTTS}, + PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseTTS}, + DefaultUsecases: []string{UsecaseChat}, + Description: "MLX audio models — text generation and TTS", + }, + + // --- Image/video generation backends --- + "diffusers": { + GRPCMethods: []GRPCMethod{MethodGenerateImage, MethodGenerateVideo}, + PossibleUsecases: []string{UsecaseImage, UsecaseVideo}, + DefaultUsecases: []string{UsecaseImage}, + Description: "HuggingFace diffusers — Stable Diffusion, Flux, video generation", + }, + "stablediffusion": { + GRPCMethods: []GRPCMethod{MethodGenerateImage}, + PossibleUsecases: []string{UsecaseImage}, + DefaultUsecases: []string{UsecaseImage}, + Description: "Stable Diffusion native backend", + }, + "stablediffusion-ggml": { + GRPCMethods: []GRPCMethod{MethodGenerateImage}, + PossibleUsecases: []string{UsecaseImage}, + DefaultUsecases: []string{UsecaseImage}, + Description: "Stable Diffusion via GGML quantized models", + }, + + // --- Speech-to-text backends --- + "whisper": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodVAD}, + PossibleUsecases: []string{UsecaseTranscript, UsecaseVAD}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "OpenAI Whisper — speech 
recognition and voice activity detection", + }, + "faster-whisper": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "CTranslate2-accelerated Whisper for faster transcription", + }, + "whisperx": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "WhisperX — Whisper with word-level timestamps and speaker diarization", + }, + "moonshine": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "Moonshine speech recognition", + }, + "nemo": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "NVIDIA NeMo speech recognition", + }, + "qwen-asr": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "Qwen automatic speech recognition", + }, + "voxtral": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription}, + PossibleUsecases: []string{UsecaseTranscript}, + DefaultUsecases: []string{UsecaseTranscript}, + Description: "Voxtral speech recognition", + }, + "vibevoice": { + GRPCMethods: []GRPCMethod{MethodAudioTranscription, MethodTTS}, + PossibleUsecases: []string{UsecaseTranscript, UsecaseTTS}, + DefaultUsecases: []string{UsecaseTranscript, UsecaseTTS}, + Description: "VibeVoice — bidirectional speech (transcription and synthesis)", + }, + + // --- TTS backends --- + "piper": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Piper — fast neural TTS optimized for Raspberry Pi", + }, + "kokoro": { + 
GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Kokoro TTS", + }, + "coqui": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Coqui TTS — multi-speaker neural synthesis", + }, + "kitten-tts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Kitten TTS", + }, + "outetts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "OuteTTS", + }, + "pocket-tts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Pocket TTS — lightweight text-to-speech", + }, + "qwen-tts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Qwen TTS", + }, + "faster-qwen3-tts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Faster Qwen3 TTS — accelerated Qwen TTS", + }, + "fish-speech": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Fish Speech TTS", + }, + "neutts": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "NeuTTS — neural text-to-speech", + }, + "chatterbox": { + GRPCMethods: []GRPCMethod{MethodTTS}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + Description: "Chatterbox TTS", + }, + "voxcpm": { + GRPCMethods: []GRPCMethod{MethodTTS, MethodTTSStream}, + PossibleUsecases: []string{UsecaseTTS}, + DefaultUsecases: []string{UsecaseTTS}, + 
Description: "VoxCPM TTS with streaming support", + }, + + // --- Sound generation backends --- + "ace-step": { + GRPCMethods: []GRPCMethod{MethodTTS, MethodSoundGeneration}, + PossibleUsecases: []string{UsecaseTTS, UsecaseSoundGeneration}, + DefaultUsecases: []string{UsecaseSoundGeneration}, + Description: "ACE-Step — music and sound generation", + }, + "acestep-cpp": { + GRPCMethods: []GRPCMethod{MethodSoundGeneration}, + PossibleUsecases: []string{UsecaseSoundGeneration}, + DefaultUsecases: []string{UsecaseSoundGeneration}, + Description: "ACE-Step C++ — native sound generation", + }, + "transformers-musicgen": { + GRPCMethods: []GRPCMethod{MethodTTS, MethodSoundGeneration}, + PossibleUsecases: []string{UsecaseTTS, UsecaseSoundGeneration}, + DefaultUsecases: []string{UsecaseSoundGeneration}, + Description: "Meta MusicGen via transformers — music generation from text", + }, + + // --- Utility backends --- + "rerankers": { + GRPCMethods: []GRPCMethod{MethodRerank}, + PossibleUsecases: []string{UsecaseRerank}, + DefaultUsecases: []string{UsecaseRerank}, + Description: "Cross-encoder reranking models", + }, + "rfdetr": { + GRPCMethods: []GRPCMethod{MethodDetect}, + PossibleUsecases: []string{UsecaseDetection}, + DefaultUsecases: []string{UsecaseDetection}, + Description: "RF-DETR object detection", + }, + "silero-vad": { + GRPCMethods: []GRPCMethod{MethodVAD}, + PossibleUsecases: []string{UsecaseVAD}, + DefaultUsecases: []string{UsecaseVAD}, + Description: "Silero VAD — voice activity detection", + }, +} + +// NormalizeBackendName converts backend names to the canonical hyphenated form +// used in gallery entries (e.g., "llama.cpp" → "llama-cpp"). +func NormalizeBackendName(backend string) string { + return strings.ReplaceAll(backend, ".", "-") +} + +// GetBackendCapability returns the capability info for a backend, or nil if unknown. +// Handles backend name normalization. 
+func GetBackendCapability(backend string) *BackendCapability { + if cap, ok := BackendCapabilities[NormalizeBackendName(backend)]; ok { + return &cap + } + return nil +} + +// PossibleUsecasesForBackend returns all usecases a backend can support. +// Returns nil if the backend is unknown. +func PossibleUsecasesForBackend(backend string) []string { + if cap := GetBackendCapability(backend); cap != nil { + return cap.PossibleUsecases + } + return nil +} + +// DefaultUsecasesForBackend returns the conservative default usecases. +// Returns nil if the backend is unknown. +func DefaultUsecasesForBackendCap(backend string) []string { + if cap := GetBackendCapability(backend); cap != nil { + return cap.DefaultUsecases + } + return nil +} + +// IsValidUsecaseForBackend checks whether a usecase is in a backend's possible set. +// Returns true for unknown backends (permissive fallback). +func IsValidUsecaseForBackend(backend, usecase string) bool { + cap := GetBackendCapability(backend) + if cap == nil { + return true // unknown backend — don't restrict + } + return slices.Contains(cap.PossibleUsecases, usecase) +} + +// AllBackendNames returns a sorted list of all known backend names. 
+func AllBackendNames() []string { + names := make([]string, 0, len(BackendCapabilities)) + for name := range BackendCapabilities { + names = append(names, name) + } + slices.Sort(names) + return names +} diff --git a/core/config/backend_capabilities_test.go b/core/config/backend_capabilities_test.go new file mode 100644 index 000000000000..d3ca74a18241 --- /dev/null +++ b/core/config/backend_capabilities_test.go @@ -0,0 +1,116 @@ +package config + +import ( + "slices" + "strings" + "testing" +) + +func TestBackendCapabilities_AllHaveUsecases(t *testing.T) { + for name, cap := range BackendCapabilities { + if len(cap.PossibleUsecases) == 0 { + t.Errorf("backend %q has no possible usecases", name) + } + if len(cap.DefaultUsecases) == 0 { + t.Errorf("backend %q has no default usecases", name) + } + if len(cap.GRPCMethods) == 0 { + t.Errorf("backend %q has no gRPC methods", name) + } + } +} + +func TestBackendCapabilities_DefaultsSubsetOfPossible(t *testing.T) { + for name, cap := range BackendCapabilities { + for _, d := range cap.DefaultUsecases { + if !slices.Contains(cap.PossibleUsecases, d) { + t.Errorf("backend %q: default %q not in possible %v", name, d, cap.PossibleUsecases) + } + } + } +} + +func TestBackendCapabilities_UsecasesMatchFlags(t *testing.T) { + allFlags := GetAllModelConfigUsecases() + for name, cap := range BackendCapabilities { + for _, u := range cap.PossibleUsecases { + info, ok := UsecaseInfoMap[u] + if !ok { + t.Errorf("backend %q: usecase %q not in UsecaseInfoMap", name, u) + continue + } + flagName := "FLAG_" + strings.ToUpper(u) + if _, ok := allFlags[flagName]; !ok { + // Try without transform — some names differ + found := false + for _, flag := range allFlags { + if flag == info.Flag { + found = true + break + } + } + if !found { + t.Errorf("backend %q: usecase %q flag %d not in GetAllModelConfigUsecases", name, u, info.Flag) + } + } + } + } +} + +func TestUsecaseInfoMap_AllHaveFlags(t *testing.T) { + for name, info := range 
UsecaseInfoMap { + if info.Flag == FLAG_ANY { + t.Errorf("usecase %q has FLAG_ANY (zero) — should have a real flag", name) + } + if info.GRPCMethod == "" { + t.Errorf("usecase %q has no gRPC method", name) + } + } +} + +func TestGetBackendCapability(t *testing.T) { + cap := GetBackendCapability("llama-cpp") + if cap == nil { + t.Fatal("llama-cpp should be known") + } + if !slices.Contains(cap.PossibleUsecases, "chat") { + t.Error("llama-cpp should support chat") + } +} + +func TestGetBackendCapability_Normalize(t *testing.T) { + cap := GetBackendCapability("llama.cpp") + if cap == nil { + t.Fatal("llama.cpp should normalize to llama-cpp") + } +} + +func TestGetBackendCapability_Unknown(t *testing.T) { + cap := GetBackendCapability("nonexistent") + if cap != nil { + t.Error("unknown backend should return nil") + } +} + +func TestIsValidUsecaseForBackend(t *testing.T) { + if !IsValidUsecaseForBackend("piper", "tts") { + t.Error("piper should support tts") + } + if IsValidUsecaseForBackend("piper", "chat") { + t.Error("piper should not support chat") + } + // Unknown backend is permissive + if !IsValidUsecaseForBackend("unknown", "anything") { + t.Error("unknown backend should allow any usecase") + } +} + +func TestAllBackendNames(t *testing.T) { + names := AllBackendNames() + if len(names) < 30 { + t.Errorf("expected 30+ backends, got %d", len(names)) + } + if !slices.IsSorted(names) { + t.Error("should be sorted") + } +} diff --git a/core/config/model_config.go b/core/config/model_config.go index a4815c766755..ecdaf723a533 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -565,11 +565,39 @@ const ( FLAG_VAD ModelConfigUsecase = 0b010000000000 FLAG_VIDEO ModelConfigUsecase = 0b100000000000 FLAG_DETECTION ModelConfigUsecase = 0b1000000000000 + FLAG_VISION ModelConfigUsecase = 0b10000000000000 // Common Subsets FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT ) +// ModalityGroups defines groups of usecases that belong to 
the same modality. +// Flags within the same group are NOT orthogonal (e.g., chat and completion are +// both text/language). A model is multimodal when its usecases span 2+ groups. +var ModalityGroups = []ModelConfigUsecase{ + FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT, // text/language + FLAG_VISION | FLAG_DETECTION, // visual understanding + FLAG_TRANSCRIPT, // speech input + FLAG_TTS | FLAG_SOUND_GENERATION, // audio output + FLAG_IMAGE | FLAG_VIDEO, // visual generation +} + +// IsMultimodal returns true if the given usecases span two or more orthogonal +// modality groups. For example chat+vision is multimodal, but chat+completion +// is not (both belong to the text/language group). +func IsMultimodal(usecases ModelConfigUsecase) bool { + groupCount := 0 + for _, group := range ModalityGroups { + if usecases&group != 0 { + groupCount++ + if groupCount >= 2 { + return true + } + } + } + return false +} + func GetAllModelConfigUsecases() map[string]ModelConfigUsecase { return map[string]ModelConfigUsecase{ // Note: FLAG_ANY is intentionally excluded from this map @@ -588,6 +616,7 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase { "FLAG_LLM": FLAG_LLM, "FLAG_VIDEO": FLAG_VIDEO, "FLAG_DETECTION": FLAG_DETECTION, + "FLAG_VISION": FLAG_VISION, } } diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index 0b0791afe75f..ef2a9f6c2264 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -7,6 +7,8 @@ import ( "path/filepath" "slices" "strings" + "sync" + "sync/atomic" "time" "github.com/lithammer/fuzzysearch/fuzzy" @@ -92,6 +94,34 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] { return filteredModels } +// FilterGalleryModelsByUsecase returns models whose known_usecases include all +// the bits set in usecase. For example, passing FLAG_CHAT matches any model +// with the chat usecase; passing FLAG_CHAT|FLAG_VISION matches only models +// that have both. 
+func FilterGalleryModelsByUsecase(models GalleryElements[*GalleryModel], usecase config.ModelConfigUsecase) GalleryElements[*GalleryModel] { + var filtered GalleryElements[*GalleryModel] + for _, m := range models { + u := m.GetKnownUsecases() + if u != nil && (*u&usecase) == usecase { + filtered = append(filtered, m) + } + } + return filtered +} + +// FilterGalleryModelsByMultimodal returns models whose known_usecases span two +// or more orthogonal modality groups (e.g. chat+vision, tts+transcript). +func FilterGalleryModelsByMultimodal(models GalleryElements[*GalleryModel]) GalleryElements[*GalleryModel] { + var filtered GalleryElements[*GalleryModel] + for _, m := range models { + u := m.GetKnownUsecases() + if u != nil && config.IsMultimodal(*u) { + filtered = append(filtered, m) + } + } + return filtered +} + func (gm GalleryElements[T]) FilterByTag(tag string) GalleryElements[T] { var filtered GalleryElements[T] for _, m := range gm { @@ -267,6 +297,69 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst return models, nil } +var ( + availableModelsMu sync.RWMutex + availableModelsCache GalleryElements[*GalleryModel] + refreshing atomic.Bool +) + +// AvailableGalleryModelsCached returns gallery models from an in-memory cache. +// Local-only fields (installed status) are refreshed on every call. A background +// goroutine is triggered to re-fetch the full model list (including network +// calls) so subsequent requests pick up changes without blocking the caller. +// The first call with an empty cache blocks until the initial load completes. +func AvailableGalleryModelsCached(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryModel], error) { + availableModelsMu.RLock() + cached := availableModelsCache + availableModelsMu.RUnlock() + + if cached != nil { + // Refresh installed status under write lock to avoid races with + // concurrent readers and the background refresh goroutine. 
+ availableModelsMu.Lock() + for _, m := range cached { + _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", m.GetName()))) + m.SetInstalled(err == nil) + } + availableModelsMu.Unlock() + // Trigger a background refresh if one is not already running. + triggerGalleryRefresh(galleries, systemState) + return cached, nil + } + + // No cache yet — must do a blocking load. + models, err := AvailableGalleryModels(galleries, systemState) + if err != nil { + return nil, err + } + + availableModelsMu.Lock() + availableModelsCache = models + availableModelsMu.Unlock() + + return models, nil +} + +// triggerGalleryRefresh starts a background goroutine that refreshes the +// gallery model cache. Only one refresh runs at a time; concurrent calls +// are no-ops. +func triggerGalleryRefresh(galleries []config.Gallery, systemState *system.SystemState) { + if !refreshing.CompareAndSwap(false, true) { + return + } + go func() { + defer refreshing.Store(false) + models, err := AvailableGalleryModels(galleries, systemState) + if err != nil { + xlog.Error("background gallery refresh failed", "error", err) + return + } + availableModelsMu.Lock() + availableModelsCache = models + availableModelsMu.Unlock() + }() +} + // List available backends func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) { return availableBackendsWithFilter(galleries, systemState, true) diff --git a/core/gallery/importers/diffuser.go b/core/gallery/importers/diffuser.go index c702da3d3025..1060899aa7a5 100644 --- a/core/gallery/importers/diffuser.go +++ b/core/gallery/importers/diffuser.go @@ -93,7 +93,7 @@ func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error) modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"image"}, + KnownUsecaseStrings: []string{config.UsecaseImage}, Backend: backend, PredictionOptions: 
schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ diff --git a/core/gallery/importers/llama-cpp.go b/core/gallery/importers/llama-cpp.go index edd9387913c3..45e91154e347 100644 --- a/core/gallery/importers/llama-cpp.go +++ b/core/gallery/importers/llama-cpp.go @@ -104,7 +104,7 @@ func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Options: []string{"use_jinja:true"}, Backend: "llama-cpp", TemplateConfig: config.TemplateConfig{ diff --git a/core/gallery/importers/local.go b/core/gallery/importers/local.go index 2a456cc6020d..73020ceceea6 100644 --- a/core/gallery/importers/local.go +++ b/core/gallery/importers/local.go @@ -42,7 +42,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { cfg := &config.ModelConfig{ Name: name, Backend: "llama-cpp", - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Options: []string{"use_jinja:true"}, } cfg.Model = relPath(ggufFile) @@ -60,7 +60,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { cfg := &config.ModelConfig{ Name: name, Backend: "transformers", - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, } cfg.Model = baseModel cfg.TemplateConfig.UseTokenizerTemplate = true @@ -76,7 +76,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { cfg := &config.ModelConfig{ Name: name, Backend: "transformers", - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, } cfg.Model = baseModel cfg.TemplateConfig.UseTokenizerTemplate = true @@ -91,7 +91,7 @@ func ImportLocalPath(dirPath, name string) (*config.ModelConfig, error) { cfg := &config.ModelConfig{ Name: name, Backend: "transformers", - KnownUsecaseStrings: 
[]string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, } cfg.Model = relPath(dirPath) cfg.TemplateConfig.UseTokenizerTemplate = true diff --git a/core/gallery/importers/mlx.go b/core/gallery/importers/mlx.go index 7ab513f6dd19..feac13129015 100644 --- a/core/gallery/importers/mlx.go +++ b/core/gallery/importers/mlx.go @@ -69,7 +69,7 @@ func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) { modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ diff --git a/core/gallery/importers/transformers.go b/core/gallery/importers/transformers.go index 5a4732ca896c..dbed7402dcf1 100644 --- a/core/gallery/importers/transformers.go +++ b/core/gallery/importers/transformers.go @@ -83,7 +83,7 @@ func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, err modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ diff --git a/core/gallery/importers/vllm.go b/core/gallery/importers/vllm.go index 88baef1fefa8..20439da52451 100644 --- a/core/gallery/importers/vllm.go +++ b/core/gallery/importers/vllm.go @@ -73,7 +73,7 @@ func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) { modelConfig := config.ModelConfig{ Name: name, Description: description, - KnownUsecaseStrings: []string{"chat"}, + KnownUsecaseStrings: []string{config.UsecaseChat}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ diff --git a/core/gallery/models_types.go b/core/gallery/models_types.go index f70a5b222567..c3de03efcbfa 100644 --- 
a/core/gallery/models_types.go +++ b/core/gallery/models_types.go @@ -52,3 +52,23 @@ func (m *GalleryModel) GetTags() []string { func (m *GalleryModel) GetDescription() string { return m.Description } + +// GetKnownUsecases extracts known_usecases from the model's Overrides and +// returns the parsed usecase flags. Returns nil when no usecases are declared. +func (m *GalleryModel) GetKnownUsecases() *config.ModelConfigUsecase { + raw, ok := m.Overrides["known_usecases"] + if !ok { + return nil + } + list, ok := raw.([]any) + if !ok { + return nil + } + strs := make([]string, 0, len(list)) + for _, v := range list { + if s, ok := v.(string); ok { + strs = append(strs, s) + } + } + return config.GetUsecasesFromYAML(strs) +} diff --git a/core/http/endpoints/localai/config_meta.go b/core/http/endpoints/localai/config_meta.go index 22d055b999a8..103081c30eb8 100644 --- a/core/http/endpoints/localai/config_meta.go +++ b/core/http/endpoints/localai/config_meta.go @@ -120,13 +120,13 @@ func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a capability := strings.TrimPrefix(provider, "models:") var filterFn config.ModelConfigFilterFn switch capability { - case "chat": + case config.UsecaseChat: filterFn = config.BuildUsecaseFilterFn(config.FLAG_CHAT) - case "tts": + case config.UsecaseTTS: filterFn = config.BuildUsecaseFilterFn(config.FLAG_TTS) - case "vad": + case config.UsecaseVAD: filterFn = config.BuildUsecaseFilterFn(config.FLAG_VAD) - case "transcript": + case config.UsecaseTranscript: filterFn = config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT) default: filterFn = config.NoFilterFn diff --git a/core/http/react-ui/e2e/models-gallery.spec.js b/core/http/react-ui/e2e/models-gallery.spec.js index ed5be1e56f5a..f0936c436299 100644 --- a/core/http/react-ui/e2e/models-gallery.spec.js +++ b/core/http/react-ui/e2e/models-gallery.spec.js @@ -2,13 +2,13 @@ import { test, expect } from '@playwright/test' const MOCK_MODELS_RESPONSE = { models: [ - { 
name: 'llama-model', description: 'A llama model', backend: 'llama-cpp', installed: false, tags: ['llm'] }, - { name: 'whisper-model', description: 'A whisper model', backend: 'whisper', installed: true, tags: ['stt'] }, + { name: 'llama-model', description: 'A llama model', backend: 'llama-cpp', installed: false, tags: ['chat'] }, + { name: 'whisper-model', description: 'A whisper model', backend: 'whisper', installed: true, tags: ['transcript'] }, { name: 'stablediffusion-model', description: 'An image model', backend: 'stablediffusion', installed: false, tags: ['sd'] }, { name: 'unknown-model', description: 'No backend', backend: '', installed: false, tags: [] }, ], allBackends: ['llama-cpp', 'stablediffusion', 'whisper'], - allTags: ['llm', 'sd', 'stt'], + allTags: ['chat', 'sd', 'transcript'], availableModels: 4, installedModels: 1, totalPages: 1, @@ -78,3 +78,121 @@ test.describe('Models Gallery - Backend Features', () => { await expect(detail.locator('text=llama-cpp')).toBeVisible() }) }) + +const BACKEND_USECASES_MOCK = { + 'llama-cpp': ['chat', 'embeddings', 'vision'], + 'whisper': ['transcript'], + 'stablediffusion': ['image'], +} + +test.describe('Models Gallery - Multi-select Filters', () => { + test.beforeEach(async ({ page }) => { + await page.route('**/api/models*', (route) => { + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify(MOCK_MODELS_RESPONSE), + }) + }) + await page.route('**/api/backends/usecases', (route) => { + route.fulfill({ + contentType: 'application/json', + body: JSON.stringify(BACKEND_USECASES_MOCK), + }) + }) + await page.goto('/app/models') + await expect(page.locator('th', { hasText: 'Backend' })).toBeVisible({ timeout: 10_000 }) + }) + + test('multi-select toggle: click Chat, TTS, then Chat again', async ({ page }) => { + const chatBtn = page.locator('.filter-btn', { hasText: 'Chat' }) + const ttsBtn = page.locator('.filter-btn', { hasText: 'TTS' }) + + await chatBtn.click() + await 
expect(chatBtn).toHaveClass(/active/) + + await ttsBtn.click() + await expect(chatBtn).toHaveClass(/active/) + await expect(ttsBtn).toHaveClass(/active/) + + // Click Chat again to deselect it + await chatBtn.click() + await expect(chatBtn).not.toHaveClass(/active/) + await expect(ttsBtn).toHaveClass(/active/) + }) + + test('"All" clears selection', async ({ page }) => { + const chatBtn = page.locator('.filter-btn', { hasText: 'Chat' }) + const allBtn = page.locator('.filter-btn', { hasText: 'All' }) + + await chatBtn.click() + await expect(chatBtn).toHaveClass(/active/) + + await allBtn.click() + await expect(allBtn).toHaveClass(/active/) + await expect(chatBtn).not.toHaveClass(/active/) + }) + + test('query param sent correctly with multiple filters', async ({ page }) => { + const chatBtn = page.locator('.filter-btn', { hasText: 'Chat' }) + const ttsBtn = page.locator('.filter-btn', { hasText: 'TTS' }) + + // Click Chat and wait for its request to settle + await chatBtn.click() + await page.waitForResponse(resp => resp.url().includes('/api/models')) + + // Now click TTS and capture the resulting request + const [request] = await Promise.all([ + page.waitForRequest(req => { + if (!req.url().includes('/api/models')) return false + const u = new URL(req.url()) + const tag = u.searchParams.get('tag') + return tag && tag.split(',').length >= 2 + }), + ttsBtn.click(), + ]) + + const url = new URL(request.url()) + const tags = url.searchParams.get('tag').split(',').sort() + expect(tags).toEqual(['chat', 'tts']) + }) + + test('backend greys out unavailable filters', async ({ page }) => { + // Select llama-cpp backend via dropdown + await page.locator('button', { hasText: 'All Backends' }).click() + const dropdown = page.locator('input[placeholder="Search backends..."]').locator('..').locator('..') + await dropdown.locator('text=llama-cpp').click() + + // Wait for filter state to update + const ttsBtn = page.locator('.filter-btn', { hasText: 'TTS' }) + const sttBtn = 
page.locator('.filter-btn', { hasText: 'STT' }) + const imageBtn = page.locator('.filter-btn', { hasText: 'Image' }) + + // TTS, STT, Image should be disabled for llama-cpp + await expect(ttsBtn).toBeDisabled() + await expect(sttBtn).toBeDisabled() + await expect(imageBtn).toBeDisabled() + + // Chat, Embeddings, Vision should remain enabled + const chatBtn = page.locator('.filter-btn', { hasText: 'Chat' }) + const embBtn = page.locator('.filter-btn', { hasText: 'Embeddings' }) + const visBtn = page.locator('.filter-btn', { hasText: 'Vision' }) + await expect(chatBtn).toBeEnabled() + await expect(embBtn).toBeEnabled() + await expect(visBtn).toBeEnabled() + }) + + test('backend clears incompatible filters', async ({ page }) => { + // Select TTS filter first + const ttsBtn = page.locator('.filter-btn', { hasText: 'TTS' }) + await ttsBtn.click() + await expect(ttsBtn).toHaveClass(/active/) + + // Now select llama-cpp backend (which doesn't support TTS) + await page.locator('button', { hasText: 'All Backends' }).click() + const dropdown = page.locator('input[placeholder="Search backends..."]').locator('..').locator('..') + await dropdown.locator('text=llama-cpp').click() + + // TTS should be auto-removed from selection + await expect(ttsBtn).not.toHaveClass(/active/) + }) +}) diff --git a/core/http/react-ui/src/pages/Backends.jsx b/core/http/react-ui/src/pages/Backends.jsx index 61e6468b9b69..aa63b463ebf4 100644 --- a/core/http/react-ui/src/pages/Backends.jsx +++ b/core/http/react-ui/src/pages/Backends.jsx @@ -139,11 +139,11 @@ export default function Backends() { const FILTERS = [ { key: '', label: 'All', icon: 'fa-layer-group' }, - { key: 'llm', label: 'LLM', icon: 'fa-brain' }, + { key: 'chat', label: 'Chat', icon: 'fa-brain' }, { key: 'image', label: 'Image', icon: 'fa-image' }, { key: 'video', label: 'Video', icon: 'fa-video' }, { key: 'tts', label: 'TTS', icon: 'fa-microphone' }, - { key: 'stt', label: 'STT', icon: 'fa-headphones' }, + { key: 'transcript', label: 
'STT', icon: 'fa-headphones' }, { key: 'vision', label: 'Vision', icon: 'fa-eye' }, ] diff --git a/core/http/react-ui/src/pages/Models.jsx b/core/http/react-ui/src/pages/Models.jsx index fd1f3f6c9bd9..2192710befea 100644 --- a/core/http/react-ui/src/pages/Models.jsx +++ b/core/http/react-ui/src/pages/Models.jsx @@ -88,14 +88,14 @@ function GalleryLoader() { const FILTERS = [ { key: '', label: 'All', icon: 'fa-layer-group' }, - { key: 'llm', label: 'LLM', icon: 'fa-brain' }, - { key: 'sd', label: 'Image', icon: 'fa-image' }, + { key: 'chat', label: 'Chat', icon: 'fa-brain' }, + { key: 'image', label: 'Image', icon: 'fa-image' }, { key: 'multimodal', label: 'Multimodal', icon: 'fa-shapes' }, { key: 'vision', label: 'Vision', icon: 'fa-eye' }, { key: 'tts', label: 'TTS', icon: 'fa-microphone' }, - { key: 'stt', label: 'STT', icon: 'fa-headphones' }, - { key: 'embedding', label: 'Embedding', icon: 'fa-vector-square' }, - { key: 'reranker', label: 'Rerank', icon: 'fa-sort' }, + { key: 'transcript', label: 'STT', icon: 'fa-headphones' }, + { key: 'embeddings', label: 'Embeddings', icon: 'fa-vector-square' }, + { key: 'rerank', label: 'Rerank', icon: 'fa-sort' }, ] export default function Models() { @@ -108,7 +108,7 @@ export default function Models() { const [page, setPage] = useState(1) const [totalPages, setTotalPages] = useState(1) const [search, setSearch] = useState('') - const [filter, setFilter] = useState('') + const [filters, setFilters] = useState([]) const [sort, setSort] = useState('') const [order, setOrder] = useState('asc') const [installing, setInstalling] = useState(new Map()) @@ -117,6 +117,8 @@ export default function Models() { const [stats, setStats] = useState({ total: 0, installed: 0, repositories: 0 }) const [backendFilter, setBackendFilter] = useState('') const [allBackends, setAllBackends] = useState([]) + const [backendUsecases, setBackendUsecases] = useState({}) + const [estimates, setEstimates] = useState({}) const debounceRef = useRef(null) 
const [confirmDialog, setConfirmDialog] = useState(null) @@ -127,14 +129,14 @@ export default function Models() { try { setLoading(true) const searchVal = params.search !== undefined ? params.search : search - const filterVal = params.filter !== undefined ? params.filter : filter + const filtersVal = params.filters !== undefined ? params.filters : filters const sortVal = params.sort !== undefined ? params.sort : sort const backendVal = params.backendFilter !== undefined ? params.backendFilter : backendFilter const queryParams = { page: params.page || page, items: 9, } - if (filterVal) queryParams.tag = filterVal + if (filtersVal.length > 0) queryParams.tag = filtersVal.join(',') if (searchVal) queryParams.term = searchVal if (backendVal) queryParams.backend = backendVal if (sortVal) { @@ -154,17 +156,50 @@ export default function Models() { } finally { setLoading(false) } - }, [page, search, filter, sort, order, backendFilter, addToast]) + }, [page, search, filters, sort, order, backendFilter, addToast]) useEffect(() => { fetchModels() - }, [page, filter, sort, order, backendFilter]) + }, [page, filters, sort, order, backendFilter]) + + // Fetch backend→usecase mapping once on mount + useEffect(() => { + modelsApi.backendUsecases().then(setBackendUsecases).catch(() => {}) + }, []) + + // When backend changes, remove selected filters that aren't available + useEffect(() => { + if (backendFilter && backendUsecases[backendFilter]) { + setFilters(prev => { + const possible = backendUsecases[backendFilter] + const filtered = prev.filter(k => k === 'multimodal' || possible.includes(k)) + return filtered.length !== prev.length ? filtered : prev + }) + } + }, [backendFilter, backendUsecases]) // Re-fetch when operations change (install/delete completion) useEffect(() => { if (!loading) fetchModels() }, [operations.length]) + // Fetch VRAM/size estimates asynchronously for visible models. 
+ useEffect(() => { + if (models.length === 0) return + let cancelled = false + models.forEach(model => { + const id = model.name || model.id + if (estimates[id]) return + modelsApi.estimate(id).then(est => { + if (cancelled) return + if (est && (est.SizeBytes || est.VRAMBytes)) { + setEstimates(prev => ({ ...prev, [id]: est })) + } + }).catch(() => {}) + }) + return () => { cancelled = true } + }, [models]) + const handleSearch = (value) => { setSearch(value) if (debounceRef.current) clearTimeout(debounceRef.current) @@ -174,6 +209,20 @@ export default function Models() { }, 500) } + const toggleFilter = (key) => { + if (key === '') { setFilters([]); setPage(1); return } + setFilters(prev => + prev.includes(key) ? prev.filter(k => k !== key) : [...prev, key] + ) + setPage(1) + } + + const isFilterAvailable = (key) => { + if (!backendFilter || key === '' || key === 'multimodal') return true + const possible = backendUsecases[backendFilter] + return !possible || possible.includes(key) + } + const handleSort = (col) => { if (sort === col) { setOrder(o => o === 'asc' ? 'desc' : 'asc') @@ -292,16 +341,23 @@ export default function Models() { {/* Filter buttons */}
- {FILTERS.map(f => ( - - ))} + {FILTERS.map(f => { + const isAll = f.key === '' + const active = isAll ? filters.length === 0 : filters.includes(f.key) + const available = isFilterAvailable(f.key) + return ( + + ) + })} {allBackends.length > 0 && (

No models found

- {search || filter || backendFilter + {search || filters.length > 0 || backendFilter ? 'No models match your current search or filters.' : 'The model gallery is empty.'}

- {(search || filter || backendFilter) && ( + {(search || filters.length > 0 || backendFilter) && ( @@ -359,9 +415,13 @@ export default function Models() { {models.map((model, idx) => { const name = model.name || model.id + const est = estimates[name] || {} + const sizeDisplay = est.SizeDisplay || model.estimated_size_display + const vramDisplay = est.VRAMDisplay || model.estimated_vram_display + const vramBytes = est.VRAMBytes || model.estimated_vram_bytes const installing = isInstalling(name) const progress = getOperationProgress(name) - const fit = fitsGpu(model.estimated_vram_bytes) + const fit = fitsGpu(vramBytes) const isExpanded = expandedRow === idx return ( @@ -428,15 +488,15 @@ export default function Models() { {/* Size / VRAM */}
- {(model.estimated_size_display || model.estimated_vram_display) ? ( + {(sizeDisplay || vramDisplay) ? ( <> - {model.estimated_size_display && model.estimated_size_display !== '0 B' && ( - Size: {model.estimated_size_display} + {sizeDisplay && sizeDisplay !== '0 B' && ( + Size: {sizeDisplay} )} - {model.estimated_size_display && model.estimated_size_display !== '0 B' && model.estimated_vram_display && model.estimated_vram_display !== '0 B' && ' · '} - {model.estimated_vram_display && model.estimated_vram_display !== '0 B' && ( - VRAM: {model.estimated_vram_display} + {sizeDisplay && sizeDisplay !== '0 B' && vramDisplay && vramDisplay !== '0 B' && ' · '} + {vramDisplay && vramDisplay !== '0 B' && ( + VRAM: {vramDisplay} )} {fit !== null && ( @@ -509,7 +569,7 @@ export default function Models() { {isExpanded && ( - + )} @@ -562,7 +622,7 @@ function DetailRow({ label, children }) { ) } -function ModelDetail({ model, fit, expandedFiles, setExpandedFiles }) { +function ModelDetail({ model, fit, sizeDisplay, vramDisplay, expandedFiles, setExpandedFiles }) { const files = model.additionalFiles || model.files || [] return (
@@ -588,12 +648,12 @@ function ModelDetail({ model, fit, expandedFiles, setExpandedFiles }) { )} - {model.estimated_size_display && model.estimated_size_display !== '0 B' ? model.estimated_size_display : null} + {sizeDisplay && sizeDisplay !== '0 B' ? sizeDisplay : null} - {model.estimated_vram_display && model.estimated_vram_display !== '0 B' ? ( + {vramDisplay && vramDisplay !== '0 B' ? ( - {model.estimated_vram_display} + {vramDisplay} {fit !== null && ( {fit ? 'Fits in GPU' : 'May not fit in GPU'} diff --git a/core/http/react-ui/src/utils/api.js b/core/http/react-ui/src/utils/api.js index ea967f973a42..ad3aba85db4b 100644 --- a/core/http/react-ui/src/utils/api.js +++ b/core/http/react-ui/src/utils/api.js @@ -79,6 +79,7 @@ export const modelsApi = { listCapabilities: () => fetchJSON(API_CONFIG.endpoints.modelsCapabilities), install: (id) => postJSON(API_CONFIG.endpoints.installModel(id), {}), delete: (id) => postJSON(API_CONFIG.endpoints.deleteModel(id), {}), + estimate: (id) => fetchJSON(API_CONFIG.endpoints.modelEstimate(id)), getConfig: (id) => postJSON(API_CONFIG.endpoints.modelConfig(id), {}), getConfigJson: (name) => fetchJSON(API_CONFIG.endpoints.modelConfigJson(name)), getJob: (uid) => fetchJSON(API_CONFIG.endpoints.modelJob(uid)), @@ -97,6 +98,7 @@ export const modelsApi = { getJobStatus: (uid) => fetchJSON(API_CONFIG.endpoints.modelsJobStatus(uid)), getEditConfig: (name) => fetchJSON(API_CONFIG.endpoints.modelEditGet(name)), editConfig: (name, body) => postJSON(API_CONFIG.endpoints.modelEdit(name), body), + backendUsecases: () => fetchJSON('/api/backends/usecases'), } // Backends API diff --git a/core/http/react-ui/src/utils/config.js b/core/http/react-ui/src/utils/config.js index e25228dcfd7a..b4b9d4fef165 100644 --- a/core/http/react-ui/src/utils/config.js +++ b/core/http/react-ui/src/utils/config.js @@ -9,6 +9,7 @@ export const API_CONFIG = { models: '/api/models', installModel: (id) => `/api/models/install/${id}`, deleteModel: (id) => 
`/api/models/delete/${id}`, + modelEstimate: (id) => `/api/models/estimate/${id}`, modelConfig: (id) => `/api/models/config/${id}`, modelConfigJson: (name) => `/api/models/config-json/${name}`, modelJob: (uid) => `/api/models/job/${uid}`, diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go index 24cdba24aa1d..0eca3093ada5 100644 --- a/core/http/routes/ui_api.go +++ b/core/http/routes/ui_api.go @@ -13,7 +13,6 @@ import ( "slices" "strconv" "strings" - "sync" "time" "github.com/google/uuid" @@ -38,8 +37,60 @@ const ( licenseSortFieldName = "license" statusSortFieldName = "status" ascSortOrder = "asc" + multimodalFilterKey = "multimodal" ) +var galleryWeightExts = map[string]bool{".gguf": true, ".safetensors": true, ".bin": true, ".pt": true} + +// usecaseFilters maps UI filter keys to ModelConfigUsecase flags for +// capability-based gallery filtering. +var usecaseFilters = map[string]config.ModelConfigUsecase{ + config.UsecaseChat: config.FLAG_CHAT, + config.UsecaseImage: config.FLAG_IMAGE, + config.UsecaseVision: config.FLAG_VISION, + config.UsecaseTTS: config.FLAG_TTS, + config.UsecaseTranscript: config.FLAG_TRANSCRIPT, + config.UsecaseEmbeddings: config.FLAG_EMBEDDINGS, + config.UsecaseRerank: config.FLAG_RERANK, +} + + +// extractHFRepo tries to find a HuggingFace repo ID from model overrides or URLs. +func extractHFRepo(overrides map[string]any, urls []string) string { + if overrides != nil { + if params, ok := overrides["parameters"].(map[string]any); ok { + if modelRef, ok := params["model"].(string); ok { + if repoID, ok := vram.ExtractHFRepoID(modelRef); ok { + return repoID + } + } + } + } + for _, u := range urls { + if repoID, ok := vram.ExtractHFRepoID(u); ok { + return repoID + } + } + return "" +} + +// buildEstimateInput creates a vram.ModelEstimateInput from gallery model metadata. 
+func buildEstimateInput(m *gallery.GalleryModel) vram.ModelEstimateInput { + var input vram.ModelEstimateInput + input.Options = vram.EstimateOptions{ContextLength: 8192} + input.Size = m.Size + if hfRepoID := extractHFRepo(m.Overrides, m.URLs); hfRepoID != "" { + input.HFRepo = hfRepoID + } + for _, f := range m.AdditionalFiles { + ext := strings.ToLower(path.Ext(path.Base(f.URI))) + if galleryWeightExts[ext] { + input.Files = append(input.Files, vram.FileInput{URI: f.URI, Size: 0}) + } + } + return input +} + // getDirectorySize calculates the total size of files in a directory func getDirectorySize(path string) (int64, error) { var totalSize int64 @@ -221,7 +272,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model items = "9" } - models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState) + models, err := gallery.AvailableGalleryModelsCached(appConfig.Galleries, appConfig.SystemState) if err != nil { xlog.Error("could not list models from galleries", "error", err) return c.JSON(http.StatusInternalServerError, map[string]any{ @@ -255,8 +306,30 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model } slices.Sort(backendNames) + // Filter by usecase tags (comma-separated for multi-select). 
if tag != "" { - models = gallery.GalleryElements[*gallery.GalleryModel](models).FilterByTag(tag) + var combinedFlag config.ModelConfigUsecase + hasMultimodal := false + var plainTags []string + for _, t := range strings.Split(tag, ",") { + t = strings.TrimSpace(t) + if t == multimodalFilterKey { + hasMultimodal = true + } else if flag, ok := usecaseFilters[t]; ok { + combinedFlag |= flag + } else if t != "" { + plainTags = append(plainTags, t) + } + } + if hasMultimodal { + models = gallery.FilterGalleryModelsByMultimodal(models) + } + if combinedFlag != config.FLAG_ANY { + models = gallery.FilterGalleryModelsByUsecase(models, combinedFlag) + } + for _, pt := range plainTags { + models = gallery.GalleryElements[*gallery.GalleryModel](models).FilterByTag(pt) + } } if term != "" { models = gallery.GalleryElements[*gallery.GalleryModel](models).Search(term) @@ -316,41 +389,6 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model modelsJSON := make([]map[string]any, 0, len(models)) seenIDs := make(map[string]bool) - weightExts := map[string]bool{".gguf": true, ".safetensors": true, ".bin": true, ".pt": true} - extractHFRepo := func(overrides map[string]any, urls []string) string { - // Try overrides.parameters.model first - if overrides != nil { - if params, ok := overrides["parameters"].(map[string]any); ok { - if modelRef, ok := params["model"].(string); ok { - if repoID, ok := vram.ExtractHFRepoID(modelRef); ok { - return repoID - } - } - } - } - // Fall back to the first HuggingFace URL in the metadata urls list - for _, u := range urls { - if repoID, ok := vram.ExtractHFRepoID(u); ok { - return repoID - } - } - return "" - } - hasWeightFiles := func(files []gallery.File) bool { - for _, f := range files { - ext := strings.ToLower(path.Ext(path.Base(f.URI))) - if weightExts[ext] { - return true - } - } - return false - } - - const hfEstimateTimeout = 10 * time.Second - const estimateConcurrency = 3 - sem := make(chan struct{}, 
estimateConcurrency) - var wg sync.WaitGroup - for _, m := range models { modelID := m.ID() @@ -392,63 +430,9 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model "backend": m.Backend, } - // Build EstimateModel input from available metadata - var estimateInput vram.ModelEstimateInput - estimateInput.Options = vram.EstimateOptions{ContextLength: 8192} - estimateInput.Size = m.Size - if hfRepoID := extractHFRepo(m.Overrides, m.URLs); hfRepoID != "" { - estimateInput.HFRepo = hfRepoID - } - - if hasWeightFiles(m.AdditionalFiles) { - files := make([]gallery.File, len(m.AdditionalFiles)) - copy(files, m.AdditionalFiles) - for _, f := range files { - ext := strings.ToLower(path.Ext(path.Base(f.URI))) - if weightExts[ext] { - estimateInput.Files = append(estimateInput.Files, vram.FileInput{URI: f.URI, Size: 0}) - } - } - } - - // Run estimation (async for file-based and HF repo, sync for size string only) - needsAsync := len(estimateInput.Files) > 0 || estimateInput.HFRepo != "" - if needsAsync { - input := estimateInput - wg.Go(func() { - sem <- struct{}{} - defer func() { <-sem }() - ctx, cancel := context.WithTimeout(context.Background(), hfEstimateTimeout) - defer cancel() - result, err := vram.EstimateModel(ctx, input) - if err == nil { - if result.SizeBytes > 0 { - obj["estimated_size_bytes"] = result.SizeBytes - obj["estimated_size_display"] = result.SizeDisplay - } - if result.VRAMBytes > 0 { - obj["estimated_vram_bytes"] = result.VRAMBytes - obj["estimated_vram_display"] = result.VRAMDisplay - } - } - }) - } else if estimateInput.Size != "" { - result, _ := vram.EstimateModel(context.Background(), estimateInput) - if result.SizeBytes > 0 { - obj["estimated_size_bytes"] = result.SizeBytes - obj["estimated_size_display"] = result.SizeDisplay - } - if result.VRAMBytes > 0 { - obj["estimated_vram_bytes"] = result.VRAMBytes - obj["estimated_vram_display"] = result.VRAMDisplay - } - } - modelsJSON = append(modelsJSON, obj) } - wg.Wait() 
- prevPage := pageNum - 1 nextPage := pageNum + 1 if prevPage < 1 { @@ -535,6 +519,67 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model }) }) + // Returns a mapping of backend names to the usecase filter keys they support. + // Used by the gallery frontend to grey out usecase filter buttons when a + // backend is selected. + app.GET("/api/backends/usecases", func(c echo.Context) error { + result := make(map[string][]string, len(config.BackendCapabilities)) + for name, cap := range config.BackendCapabilities { + var keys []string + for _, uc := range cap.PossibleUsecases { + if _, ok := usecaseFilters[uc]; ok { + keys = append(keys, uc) + } + } + slices.Sort(keys) + result[name] = keys + } + + return c.JSON(200, result) + }, adminMiddleware) + + // Returns VRAM/size estimates for a single gallery model. The frontend + // calls this per-model so the gallery page can load instantly and fill + // in estimates asynchronously. + app.GET("/api/models/estimate/:id", func(c echo.Context) error { + modelID, err := url.QueryUnescape(c.Param("id")) + if err != nil { + return c.JSON(http.StatusBadRequest, map[string]any{"error": "invalid model ID"}) + } + + // Return cached result immediately if available. + if cached, ok := vram.GetCachedEstimate(modelID); ok { + return c.JSON(200, cached) + } + + // Look up the model from the gallery to build the estimate input. 
+ models, err := gallery.AvailableGalleryModelsCached(appConfig.Galleries, appConfig.SystemState) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()}) + } + + model := gallery.FindGalleryElement(models, modelID) + if model == nil { + return c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"}) + } + + input := buildEstimateInput(model) + if len(input.Files) == 0 && input.HFRepo == "" && input.Size == "" { + return c.JSON(200, vram.EstimateResult{}) + } + + ctx, cancel := context.WithTimeout(c.Request().Context(), 10*time.Second) + defer cancel() + result, err := vram.EstimateModel(ctx, input) + if err != nil { + xlog.Debug("model estimate failed", "model", modelID, "error", err) + return c.JSON(200, vram.EstimateResult{}) + } + + vram.SetCachedEstimate(modelID, result) + return c.JSON(200, result) + }, adminMiddleware) + app.POST("/api/models/install/:id", func(c echo.Context) error { galleryID := c.Param("id") // URL decode the gallery ID (e.g., "localai%40model" -> "localai@model") @@ -638,7 +683,7 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model } xlog.Debug("API job submitted to get config", "galleryID", galleryID) - models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState) + models, err := gallery.AvailableGalleryModelsCached(appConfig.Galleries, appConfig.SystemState) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]any{ "error": err.Error(), diff --git a/pkg/vram/cache.go b/pkg/vram/cache.go index 38fd08b29666..08cf1b7ae7ba 100644 --- a/pkg/vram/cache.go +++ b/pkg/vram/cache.go @@ -94,3 +94,44 @@ var ( defaultCachedSizeResolver = CachedSizeResolver(defaultSizeResolver{}, defaultEstimateCacheTTL) defaultCachedGGUFReader = CachedGGUFReader(defaultGGUFReader{}, defaultEstimateCacheTTL) ) + +// Model-level estimate result cache — keyed by model ID, avoids re-running +// the full estimation pipeline (HTTP 
HEAD, GGUF reads, HF API) on every +// gallery page load. + +const estimateResultTTL = 1 * time.Hour + +type estimateResultEntry struct { + result EstimateResult + until time.Time +} + +var ( + estimateResultMu sync.Mutex + estimateResultCache = make(map[string]estimateResultEntry) +) + +// GetCachedEstimate returns a previously cached EstimateResult for the given +// key (typically a model ID). Returns false on cache miss or expiry. +func GetCachedEstimate(key string) (EstimateResult, bool) { + estimateResultMu.Lock() + defer estimateResultMu.Unlock() + e, ok := estimateResultCache[key] + if !ok || time.Now().After(e.until) { + if ok { + delete(estimateResultCache, key) + } + return EstimateResult{}, false + } + return e.result, true +} + +// SetCachedEstimate stores an EstimateResult for the given key with a 1-hour TTL. +func SetCachedEstimate(key string, result EstimateResult) { + estimateResultMu.Lock() + defer estimateResultMu.Unlock() + estimateResultCache[key] = estimateResultEntry{ + result: result, + until: time.Now().Add(estimateResultTTL), + } +} From e98d30f8a7eb8b8bde1c609df61f90bd8cbee51b Mon Sep 17 00:00:00 2001 From: Richard Palethorpe Date: Mon, 6 Apr 2026 11:44:32 +0100 Subject: [PATCH 2/2] feat(ui): Asynchronous VRAM estimates with multi-context and use known_usecases Signed-off-by: Richard Palethorpe --- core/application/startup.go | 5 + core/gallery/gallery.go | 8 + core/http/endpoints/localai/import_model.go | 13 +- core/http/endpoints/localai/vram.go | 84 ++++------ core/http/react-ui/src/pages/Models.jsx | 36 ++++- core/http/react-ui/src/utils/api.js | 5 +- core/http/routes/ui_api.go | 46 ++++-- core/services/nodes/router.go | 27 ++-- pkg/vram/cache.go | 112 +++++--------- pkg/vram/estimate.go | 162 +++++++++++++------- pkg/vram/estimate_test.go | 114 +++++++++++--- pkg/vram/hf_estimate.go | 60 +++----- pkg/vram/types.go | 40 +++-- 13 files changed, 404 insertions(+), 308 deletions(-) diff --git a/core/application/startup.go 
b/core/application/startup.go index 728c3c97221e..bb7ec82750c8 100644 --- a/core/application/startup.go +++ b/core/application/startup.go @@ -17,6 +17,7 @@ import ( "github.com/mudler/LocalAI/core/services/jobs" "github.com/mudler/LocalAI/core/services/nodes" "github.com/mudler/LocalAI/core/services/storage" + "github.com/mudler/LocalAI/pkg/vram" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" @@ -231,6 +232,10 @@ func New(opts ...config.AppOption) (*Application, error) { xlog.Error("error registering external backends", "error", err) } + // Wire gallery generation counter into VRAM caches so they invalidate + // when gallery data refreshes instead of using a fixed TTL. + vram.SetGalleryGenerationFunc(gallery.GalleryGeneration) + if options.ConfigFile != "" { if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil { xlog.Error("error loading config file", "error", err) diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go index ef2a9f6c2264..b7667b234bc7 100644 --- a/core/gallery/gallery.go +++ b/core/gallery/gallery.go @@ -301,8 +301,14 @@ var ( availableModelsMu sync.RWMutex availableModelsCache GalleryElements[*GalleryModel] refreshing atomic.Bool + galleryGeneration atomic.Uint64 ) +// GalleryGeneration returns a counter that increments each time the gallery +// model list is refreshed from upstream. VRAM estimation caches use this to +// invalidate entries when the gallery data changes. +func GalleryGeneration() uint64 { return galleryGeneration.Load() } + // AvailableGalleryModelsCached returns gallery models from an in-memory cache. // Local-only fields (installed status) are refreshed on every call. 
A background // goroutine is triggered to re-fetch the full model list (including network @@ -335,6 +341,7 @@ func AvailableGalleryModelsCached(galleries []config.Gallery, systemState *syste availableModelsMu.Lock() availableModelsCache = models + galleryGeneration.Add(1) availableModelsMu.Unlock() return models, nil @@ -356,6 +363,7 @@ func triggerGalleryRefresh(galleries []config.Gallery, systemState *system.Syste } availableModelsMu.Lock() availableModelsCache = models + galleryGeneration.Add(1) availableModelsMu.Unlock() }() } diff --git a/core/http/endpoints/localai/import_model.go b/core/http/endpoints/localai/import_model.go index a1931bae9117..41921c6848d8 100644 --- a/core/http/endpoints/localai/import_model.go +++ b/core/http/endpoints/localai/import_model.go @@ -51,18 +51,17 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl } estCtx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second) defer cancel() - result, err := vram.EstimateModel(estCtx, vram.ModelEstimateInput{ - Files: files, - Options: vram.EstimateOptions{ContextLength: 8192}, - }) + result, err := vram.EstimateModelMultiContext(estCtx, vram.ModelEstimateInput{ + Files: files, + }, []uint32{8192}) if err == nil { if result.SizeBytes > 0 { resp.EstimatedSizeBytes = result.SizeBytes resp.EstimatedSizeDisplay = result.SizeDisplay } - if result.VRAMBytes > 0 { - resp.EstimatedVRAMBytes = result.VRAMBytes - resp.EstimatedVRAMDisplay = result.VRAMDisplay + if v := result.VRAMForContext(8192); v > 0 { + resp.EstimatedVRAMBytes = v + resp.EstimatedVRAMDisplay = vram.FormatBytes(v) } } } diff --git a/core/http/endpoints/localai/vram.go b/core/http/endpoints/localai/vram.go index fe7b312bef80..5ac7b0fcaf40 100644 --- a/core/http/endpoints/localai/vram.go +++ b/core/http/endpoints/localai/vram.go @@ -2,9 +2,9 @@ package localai import ( "context" - "fmt" "net/http" "path/filepath" + "slices" "strings" "time" @@ -14,16 +14,10 @@ import ( ) type 
vramEstimateRequest struct { - Model string `json:"model"` // model name (must be installed) - ContextSize uint32 `json:"context_size,omitempty"` // context length to estimate for (default 8192) - GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all) - KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16) -} - -type vramEstimateResponse struct { - vram.EstimateResult - ContextNote string `json:"context_note,omitempty"` // note when context_size was defaulted - ModelMaxContext uint64 `json:"model_max_context,omitempty"` // model's trained maximum context length + Model string `json:"model"` // model name (must be installed) + ContextSizes []uint32 `json:"context_sizes,omitempty"` // context sizes to estimate (default [8192]) + GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all) + KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16) } // resolveModelURI converts a relative model path to a file:// URI so the @@ -36,8 +30,8 @@ func resolveModelURI(uri, modelsPath string) string { return "file://" + filepath.Join(modelsPath, uri) } -// addWeightFile appends a resolved weight file to files and tracks the first GGUF. -func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, firstGGUF *string, seen map[string]bool) { +// addWeightFile appends a resolved weight file to files. +func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, seen map[string]bool) { if !vram.IsWeightFile(uri) { return } @@ -47,21 +41,17 @@ func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, firstGGUF *s } seen[resolved] = true *files = append(*files, vram.FileInput{URI: resolved, Size: 0}) - if *firstGGUF == "" && vram.IsGGUF(uri) { - *firstGGUF = resolved - } } // VRAMEstimateEndpoint returns a handler that estimates VRAM usage for an -// installed model configuration. 
For uninstalled models (gallery URLs), use -// the gallery-level estimates in /api/models instead. +// installed model configuration at multiple context sizes. // @Summary Estimate VRAM usage for a model -// @Description Estimates VRAM based on model weight files, context size, and GPU layers +// @Description Estimates VRAM based on model weight files at multiple context sizes // @Tags config // @Accept json // @Produce json // @Param request body vramEstimateRequest true "VRAM estimation parameters" -// @Success 200 {object} vramEstimateResponse "VRAM estimate" +// @Success 200 {object} vram.MultiContextEstimate "VRAM estimate" // @Router /api/models/vram-estimate [post] func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { @@ -82,17 +72,16 @@ func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic modelsPath := appConfig.SystemState.Model.ModelsPath var files []vram.FileInput - var firstGGUF string seen := make(map[string]bool) for _, f := range modelConfig.DownloadFiles { - addWeightFile(string(f.URI), modelsPath, &files, &firstGGUF, seen) + addWeightFile(string(f.URI), modelsPath, &files, seen) } if modelConfig.Model != "" { - addWeightFile(modelConfig.Model, modelsPath, &files, &firstGGUF, seen) + addWeightFile(modelConfig.Model, modelsPath, &files, seen) } if modelConfig.MMProj != "" { - addWeightFile(modelConfig.MMProj, modelsPath, &files, &firstGGUF, seen) + addWeightFile(modelConfig.MMProj, modelsPath, &files, seen) } if len(files) == 0 { @@ -101,45 +90,36 @@ func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic }) } - contextDefaulted := false - opts := vram.EstimateOptions{ - ContextLength: req.ContextSize, - GPULayers: req.GPULayers, - KVQuantBits: req.KVQuantBits, - } - if opts.ContextLength == 0 { + contextSizes := req.ContextSizes + if len(contextSizes) == 0 { if modelConfig.ContextSize != nil { - 
opts.ContextLength = uint32(*modelConfig.ContextSize) + contextSizes = []uint32{uint32(*modelConfig.ContextSize)} } else { - opts.ContextLength = 8192 - contextDefaulted = true + contextSizes = []uint32{8192} + } + } + + // Include model's configured context size alongside requested sizes + if modelConfig.ContextSize != nil { + modelCtx := uint32(*modelConfig.ContextSize) + if !slices.Contains(contextSizes, modelCtx) { + contextSizes = append(contextSizes, modelCtx) } } + opts := vram.EstimateOptions{ + GPULayers: req.GPULayers, + KVQuantBits: req.KVQuantBits, + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - result, err := vram.Estimate(ctx, files, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader()) + result, err := vram.EstimateMultiContext(ctx, files, contextSizes, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader()) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()}) } - resp := vramEstimateResponse{EstimateResult: result} - - // When context was defaulted to 8192, read the GGUF metadata to report - // the model's trained maximum context length so callers know the estimate - // may be conservative. - if contextDefaulted && firstGGUF != "" { - ggufMeta, err := vram.DefaultCachedGGUFReader().ReadMetadata(ctx, firstGGUF) - if err == nil && ggufMeta != nil && ggufMeta.MaximumContextLength > 0 { - resp.ModelMaxContext = ggufMeta.MaximumContextLength - resp.ContextNote = fmt.Sprintf( - "Estimate used default context_size=8192. 
The model's trained maximum context is %d; VRAM usage will be higher at larger context sizes.", - ggufMeta.MaximumContextLength, - ) - } - } - - return c.JSON(http.StatusOK, resp) + return c.JSON(http.StatusOK, result) } } diff --git a/core/http/react-ui/src/pages/Models.jsx b/core/http/react-ui/src/pages/Models.jsx index 2192710befea..2aaab9d5ff6d 100644 --- a/core/http/react-ui/src/pages/Models.jsx +++ b/core/http/react-ui/src/pages/Models.jsx @@ -86,6 +86,9 @@ function GalleryLoader() { } +const CONTEXT_SIZES = [8192, 16384, 32768, 65536, 131072, 262144] +const CONTEXT_LABELS = ['8K', '16K', '32K', '64K', '128K', '256K'] + const FILTERS = [ { key: '', label: 'All', icon: 'fa-layer-group' }, { key: 'chat', label: 'Chat', icon: 'fa-brain' }, @@ -119,6 +122,7 @@ export default function Models() { const [allBackends, setAllBackends] = useState([]) const [backendUsecases, setBackendUsecases] = useState({}) const [estimates, setEstimates] = useState({}) + const [contextSize, setContextSize] = useState(CONTEXT_SIZES[0]) const debounceRef = useRef(null) const [confirmDialog, setConfirmDialog] = useState(null) @@ -190,9 +194,9 @@ export default function Models() { models.forEach(model => { const id = model.name || model.id if (estimates[id]) return - modelsApi.estimate(id).then(est => { + modelsApi.estimate(id, CONTEXT_SIZES).then(est => { if (cancelled) return - if (est && (est.SizeBytes || est.VRAMBytes)) { + if (est && (est.sizeBytes || est.estimates)) { setEstimates(prev => ({ ...prev, [id]: est })) } }).catch(() => {}) @@ -371,6 +375,25 @@ export default function Models() { )}
+ {/* Context size slider for VRAM estimates */} +
+ + setContextSize(CONTEXT_SIZES[e.target.value])} + style={{ width: 140, accentColor: 'var(--color-primary)' }} + /> + + {CONTEXT_LABELS[CONTEXT_SIZES.indexOf(contextSize)]} + +
+ {/* Table */} {loading ? ( @@ -415,10 +438,11 @@ export default function Models() { {models.map((model, idx) => { const name = model.name || model.id - const est = estimates[name] || {} - const sizeDisplay = est.SizeDisplay || model.estimated_size_display - const vramDisplay = est.VRAMDisplay || model.estimated_vram_display - const vramBytes = est.VRAMBytes || model.estimated_vram_bytes + const estData = estimates[name] + const sizeDisplay = estData?.sizeDisplay + const ctxEst = estData?.estimates?.[String(contextSize)] + const vramDisplay = ctxEst?.vramDisplay + const vramBytes = ctxEst?.vramBytes const installing = isInstalling(name) const progress = getOperationProgress(name) const fit = fitsGpu(vramBytes) diff --git a/core/http/react-ui/src/utils/api.js b/core/http/react-ui/src/utils/api.js index ad3aba85db4b..8b22d5706d21 100644 --- a/core/http/react-ui/src/utils/api.js +++ b/core/http/react-ui/src/utils/api.js @@ -79,7 +79,10 @@ export const modelsApi = { listCapabilities: () => fetchJSON(API_CONFIG.endpoints.modelsCapabilities), install: (id) => postJSON(API_CONFIG.endpoints.installModel(id), {}), delete: (id) => postJSON(API_CONFIG.endpoints.deleteModel(id), {}), - estimate: (id) => fetchJSON(API_CONFIG.endpoints.modelEstimate(id)), + estimate: (id, contexts) => fetchJSON( + buildUrl(API_CONFIG.endpoints.modelEstimate(id), + contexts?.length ? 
{ contexts: contexts.join(',') } : {}) + ), getConfig: (id) => postJSON(API_CONFIG.endpoints.modelConfig(id), {}), getConfigJson: (name) => fetchJSON(API_CONFIG.endpoints.modelConfigJson(name)), getJob: (uid) => fetchJSON(API_CONFIG.endpoints.modelJob(uid)), diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go index 0eca3093ada5..497964b7c716 100644 --- a/core/http/routes/ui_api.go +++ b/core/http/routes/ui_api.go @@ -9,7 +9,6 @@ import ( "math" "net/http" "net/url" - "path" "slices" "strconv" "strings" @@ -40,8 +39,6 @@ const ( multimodalFilterKey = "multimodal" ) -var galleryWeightExts = map[string]bool{".gguf": true, ".safetensors": true, ".bin": true, ".pt": true} - // usecaseFilters maps UI filter keys to ModelConfigUsecase flags for // capability-based gallery filtering. var usecaseFilters = map[string]config.ModelConfigUsecase{ @@ -77,20 +74,37 @@ func extractHFRepo(overrides map[string]any, urls []string) string { // buildEstimateInput creates a vram.ModelEstimateInput from gallery model metadata. func buildEstimateInput(m *gallery.GalleryModel) vram.ModelEstimateInput { var input vram.ModelEstimateInput - input.Options = vram.EstimateOptions{ContextLength: 8192} input.Size = m.Size if hfRepoID := extractHFRepo(m.Overrides, m.URLs); hfRepoID != "" { input.HFRepo = hfRepoID } for _, f := range m.AdditionalFiles { - ext := strings.ToLower(path.Ext(path.Base(f.URI))) - if galleryWeightExts[ext] { + if vram.IsWeightFile(f.URI) { input.Files = append(input.Files, vram.FileInput{URI: f.URI, Size: 0}) } } return input } +// parseContextSizes parses a comma-separated list of context sizes from a query param. +// Returns a default of [8192] if the param is empty or unparseable. 
+func parseContextSizes(raw string) []uint32 { + if raw == "" { + return []uint32{8192} + } + var sizes []uint32 + for _, s := range strings.Split(raw, ",") { + s = strings.TrimSpace(s) + if v, err := strconv.ParseUint(s, 10, 32); err == nil && v > 0 { + sizes = append(sizes, uint32(v)) + } + } + if len(sizes) == 0 { + return []uint32{8192} + } + return sizes +} + // getDirectorySize calculates the total size of files in a directory func getDirectorySize(path string) (int64, error) { var totalSize int64 @@ -538,19 +552,18 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model return c.JSON(200, result) }, adminMiddleware) - // Returns VRAM/size estimates for a single gallery model. The frontend - // calls this per-model so the gallery page can load instantly and fill - // in estimates asynchronously. + // Returns VRAM/size estimates for a single gallery model at multiple + // context sizes. The frontend calls this per-model so the gallery page + // can load instantly and fill in estimates asynchronously. + // Query params: + // contexts - comma-separated context sizes (default: 8192) app.GET("/api/models/estimate/:id", func(c echo.Context) error { modelID, err := url.QueryUnescape(c.Param("id")) if err != nil { return c.JSON(http.StatusBadRequest, map[string]any{"error": "invalid model ID"}) } - // Return cached result immediately if available. - if cached, ok := vram.GetCachedEstimate(modelID); ok { - return c.JSON(200, cached) - } + contextSizes := parseContextSizes(c.QueryParam("contexts")) // Look up the model from the gallery to build the estimate input. 
models, err := gallery.AvailableGalleryModelsCached(appConfig.Galleries, appConfig.SystemState) @@ -565,18 +578,17 @@ func RegisterUIAPIRoutes(app *echo.Echo, cl *config.ModelConfigLoader, ml *model input := buildEstimateInput(model) if len(input.Files) == 0 && input.HFRepo == "" && input.Size == "" { - return c.JSON(200, vram.EstimateResult{}) + return c.JSON(200, vram.MultiContextEstimate{}) } ctx, cancel := context.WithTimeout(c.Request().Context(), 10*time.Second) defer cancel() - result, err := vram.EstimateModel(ctx, input) + result, err := vram.EstimateModelMultiContext(ctx, input, contextSizes) if err != nil { xlog.Debug("model estimate failed", "model", modelID, "error", err) - return c.JSON(200, vram.EstimateResult{}) + return c.JSON(200, vram.MultiContextEstimate{}) } - vram.SetCachedEstimate(modelID, result) return c.JSON(200, result) }, adminMiddleware) diff --git a/core/services/nodes/router.go b/core/services/nodes/router.go index 8e7cc0359f75..20bb25142053 100644 --- a/core/services/nodes/router.go +++ b/core/services/nodes/router.go @@ -396,10 +396,14 @@ func (r *SmartRouter) estimateModelVRAM(ctx context.Context, opts *pb.ModelOptio estCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() + ctxSize := uint32(opts.ContextSize) + if ctxSize == 0 { + ctxSize = 8192 + } + input := vram.ModelEstimateInput{ Options: vram.EstimateOptions{ - ContextLength: uint32(opts.ContextSize), - GPULayers: int(opts.NGPULayers), + GPULayers: int(opts.NGPULayers), }, } @@ -417,28 +421,15 @@ func (r *SmartRouter) estimateModelVRAM(ctx context.Context, opts *pb.ModelOptio } } - // If model file exists, get its size as fallback - if opts.ModelFile != "" && len(input.Files) == 0 { - if info, err := os.Stat(opts.ModelFile); err == nil { - return vram.EstimateFromSize(uint64(info.Size())).VRAMBytes - } - } - if len(input.Files) == 0 && input.HFRepo == "" && input.Size == "" { return 0 } - result, err := vram.EstimateModel(estCtx, input) - if err != nil || 
result.VRAMBytes == 0 { - // Last resort: try model file size - if opts.ModelFile != "" { - if info, statErr := os.Stat(opts.ModelFile); statErr == nil { - return vram.EstimateFromSize(uint64(info.Size())).VRAMBytes - } - } + result, err := vram.EstimateModelMultiContext(estCtx, input, []uint32{ctxSize}) + if err != nil { return 0 } - return result.VRAMBytes + return result.VRAMForContext(ctxSize) } // installBackendOnNode sends a NATS backend.install request-reply to the node. diff --git a/pkg/vram/cache.go b/pkg/vram/cache.go index 08cf1b7ae7ba..cbfaefed1b94 100644 --- a/pkg/vram/cache.go +++ b/pkg/vram/cache.go @@ -3,135 +3,93 @@ package vram import ( "context" "sync" - "time" ) -const defaultEstimateCacheTTL = 15 * time.Minute +// galleryGenFunc returns the current gallery generation counter. +// When set, cache entries are invalidated when the generation changes. +// When nil (e.g., in tests or non-gallery contexts), entries never expire. +var galleryGenFunc func() uint64 + +// SetGalleryGenerationFunc wires the gallery generation counter into the +// VRAM caches. Call this once at application startup. 
+func SetGalleryGenerationFunc(fn func() uint64) { + galleryGenFunc = fn +} + +func currentGeneration() uint64 { + if galleryGenFunc != nil { + return galleryGenFunc() + } + return 0 +} type sizeCacheEntry struct { - size int64 - err error - until time.Time + size int64 + err error + generation uint64 } type cachedSizeResolver struct { underlying SizeResolver - ttl time.Duration mu sync.Mutex cache map[string]sizeCacheEntry } func (c *cachedSizeResolver) ContentLength(ctx context.Context, uri string) (int64, error) { + gen := currentGeneration() c.mu.Lock() e, ok := c.cache[uri] c.mu.Unlock() - if ok && time.Now().Before(e.until) { + if ok && e.generation == gen { return e.size, e.err } size, err := c.underlying.ContentLength(ctx, uri) c.mu.Lock() - if c.cache == nil { - c.cache = make(map[string]sizeCacheEntry) - } - c.cache[uri] = sizeCacheEntry{size: size, err: err, until: time.Now().Add(c.ttl)} + c.cache[uri] = sizeCacheEntry{size: size, err: err, generation: gen} c.mu.Unlock() return size, err } type ggufCacheEntry struct { - meta *GGUFMeta - err error - until time.Time + meta *GGUFMeta + err error + generation uint64 } type cachedGGUFReader struct { underlying GGUFMetadataReader - ttl time.Duration mu sync.Mutex cache map[string]ggufCacheEntry } func (c *cachedGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta, error) { + gen := currentGeneration() c.mu.Lock() e, ok := c.cache[uri] c.mu.Unlock() - if ok && time.Now().Before(e.until) { + if ok && e.generation == gen { return e.meta, e.err } meta, err := c.underlying.ReadMetadata(ctx, uri) c.mu.Lock() - if c.cache == nil { - c.cache = make(map[string]ggufCacheEntry) - } - c.cache[uri] = ggufCacheEntry{meta: meta, err: err, until: time.Now().Add(c.ttl)} + c.cache[uri] = ggufCacheEntry{meta: meta, err: err, generation: gen} c.mu.Unlock() return meta, err } -// CachedSizeResolver returns a SizeResolver that caches ContentLength results by URI for the given TTL. 
-func CachedSizeResolver(underlying SizeResolver, ttl time.Duration) SizeResolver { - return &cachedSizeResolver{underlying: underlying, ttl: ttl, cache: make(map[string]sizeCacheEntry)} -} - -// CachedGGUFReader returns a GGUFMetadataReader that caches ReadMetadata results by URI for the given TTL. -func CachedGGUFReader(underlying GGUFMetadataReader, ttl time.Duration) GGUFMetadataReader { - return &cachedGGUFReader{underlying: underlying, ttl: ttl, cache: make(map[string]ggufCacheEntry)} -} - -// DefaultCachedSizeResolver returns a cached SizeResolver using the default implementation and default TTL (15 min). -// A single shared cache is used so repeated HEAD requests for the same URI are avoided across requests. +// DefaultCachedSizeResolver returns a cached SizeResolver using the default implementation. +// Entries are invalidated when the gallery generation changes. func DefaultCachedSizeResolver() SizeResolver { return defaultCachedSizeResolver } -// DefaultCachedGGUFReader returns a cached GGUFMetadataReader using the default implementation and default TTL (15 min). -// A single shared cache is used so repeated GGUF metadata fetches for the same URI are avoided across requests. +// DefaultCachedGGUFReader returns a cached GGUFMetadataReader using the default implementation. +// Entries are invalidated when the gallery generation changes. func DefaultCachedGGUFReader() GGUFMetadataReader { return defaultCachedGGUFReader } var ( - defaultCachedSizeResolver = CachedSizeResolver(defaultSizeResolver{}, defaultEstimateCacheTTL) - defaultCachedGGUFReader = CachedGGUFReader(defaultGGUFReader{}, defaultEstimateCacheTTL) -) - -// Model-level estimate result cache — keyed by model ID, avoids re-running -// the full estimation pipeline (HTTP HEAD, GGUF reads, HF API) on every -// gallery page load. 
- -const estimateResultTTL = 1 * time.Hour - -type estimateResultEntry struct { - result EstimateResult - until time.Time -} - -var ( - estimateResultMu sync.Mutex - estimateResultCache = make(map[string]estimateResultEntry) + defaultCachedSizeResolver = &cachedSizeResolver{underlying: defaultSizeResolver{}, cache: make(map[string]sizeCacheEntry)} + defaultCachedGGUFReader = &cachedGGUFReader{underlying: defaultGGUFReader{}, cache: make(map[string]ggufCacheEntry)} ) - -// GetCachedEstimate returns a previously cached EstimateResult for the given -// key (typically a model ID). Returns false on cache miss or expiry. -func GetCachedEstimate(key string) (EstimateResult, bool) { - estimateResultMu.Lock() - defer estimateResultMu.Unlock() - e, ok := estimateResultCache[key] - if !ok || time.Now().After(e.until) { - if ok { - delete(estimateResultCache, key) - } - return EstimateResult{}, false - } - return e.result, true -} - -// SetCachedEstimate stores an EstimateResult for the given key with a 1-hour TTL. -func SetCachedEstimate(key string, result EstimateResult) { - estimateResultMu.Lock() - defer estimateResultMu.Unlock() - estimateResultCache[key] = estimateResultEntry{ - result: result, - until: time.Now().Add(estimateResultTTL), - } -} diff --git a/pkg/vram/estimate.go b/pkg/vram/estimate.go index f98517ab07fa..c91004a4bcd2 100644 --- a/pkg/vram/estimate.go +++ b/pkg/vram/estimate.go @@ -23,17 +23,19 @@ func IsGGUF(nameOrURI string) bool { return strings.ToLower(path.Ext(path.Base(nameOrURI))) == ".gguf" } -func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, sizeResolver SizeResolver, ggufReader GGUFMetadataReader) (EstimateResult, error) { - if opts.ContextLength == 0 { - opts.ContextLength = 8192 - } - if opts.KVQuantBits == 0 { - opts.KVQuantBits = 16 - } +// modelProfile captures the "fixed" properties of a model after I/O. +// Everything except context length is constant for a given model. 
+type modelProfile struct { + sizeBytes uint64 // total weight file size + ggufSize uint64 // GGUF file size (subset of sizeBytes) + meta *GGUFMeta // nil if no GGUF metadata available +} - var sizeBytes uint64 - var ggufSize uint64 +// resolveProfile does all I/O: iterates files, fetches sizes and GGUF metadata. +func resolveProfile(ctx context.Context, files []FileInput, sizeResolver SizeResolver, ggufReader GGUFMetadataReader) modelProfile { + var p modelProfile var firstGGUFURI string + for i := range files { f := &files[i] if !IsWeightFile(f.URI) { @@ -47,23 +49,32 @@ func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, size continue } } - sizeBytes += uint64(sz) + p.sizeBytes += uint64(sz) if IsGGUF(f.URI) { - ggufSize += uint64(sz) + p.ggufSize += uint64(sz) if firstGGUFURI == "" { firstGGUFURI = f.URI } } } - sizeDisplay := FormatBytes(sizeBytes) + if p.ggufSize > 0 && ggufReader != nil && firstGGUFURI != "" { + p.meta, _ = ggufReader.ReadMetadata(ctx, firstGGUFURI) + } - var vramBytes uint64 - if ggufSize > 0 { - var meta *GGUFMeta - if ggufReader != nil && firstGGUFURI != "" { - meta, _ = ggufReader.ReadMetadata(ctx, firstGGUFURI) - } + return p +} + +// computeVRAM is pure arithmetic — no I/O. Returns VRAM bytes for a given +// model profile and context length. 
+func computeVRAM(p modelProfile, ctxLen uint32, opts EstimateOptions) uint64 { + kvQuantBits := opts.KVQuantBits + if kvQuantBits == 0 { + kvQuantBits = 16 + } + + if p.ggufSize > 0 { + meta := p.meta if meta != nil && (meta.BlockCount > 0 || meta.EmbeddingLength > 0) { nLayers := meta.BlockCount if nLayers == 0 { @@ -84,36 +95,29 @@ func Estimate(ctx context.Context, files []FileInput, opts EstimateOptions, size if gpuLayers <= 0 { gpuLayers = int(nLayers) } - ctxLen := opts.ContextLength - bKV := uint32(opts.KVQuantBits / 8) + bKV := uint32(kvQuantBits / 8) if bKV == 0 { bKV = 4 } - M_model := ggufSize - M_KV := uint64(bKV) * uint64(dModel) * uint64(nLayers) * uint64(ctxLen) - if headCountKV > 0 && meta.HeadCount > 0 { - M_KV = uint64(bKV) * uint64(dModel) * uint64(headCountKV) * uint64(ctxLen) - } + + M_model := p.ggufSize + M_KV := uint64(bKV) * uint64(dModel) * uint64(headCountKV) * uint64(ctxLen) P := M_model * 2 M_overhead := uint64(0.02*float64(P) + 0.15*1e9) - vramBytes = M_model + M_KV + M_overhead + vramBytes := M_model + M_KV + M_overhead if nLayers > 0 && gpuLayers < int(nLayers) { layerRatio := float64(gpuLayers) / float64(nLayers) vramBytes = uint64(layerRatio*float64(M_model)) + M_KV + M_overhead } - } else { - vramBytes = sizeOnlyVRAM(ggufSize, opts.ContextLength) + return vramBytes } - } else if sizeBytes > 0 { - vramBytes = sizeOnlyVRAM(sizeBytes, opts.ContextLength) + return sizeOnlyVRAM(p.ggufSize, ctxLen) } - return EstimateResult{ - SizeBytes: sizeBytes, - SizeDisplay: sizeDisplay, - VRAMBytes: vramBytes, - VRAMDisplay: FormatBytes(vramBytes), - }, nil + if p.sizeBytes > 0 { + return sizeOnlyVRAM(p.sizeBytes, ctxLen) + } + return 0 } func sizeOnlyVRAM(sizeOnDisk uint64, ctxLen uint32) uint64 { @@ -125,6 +129,45 @@ func sizeOnlyVRAM(sizeOnDisk uint64, ctxLen uint32) uint64 { return vram } +// buildEstimates computes VRAMAt entries for each context size from a profile. 
+func buildEstimates(p modelProfile, contextSizes []uint32, opts EstimateOptions) map[string]VRAMAt { + m := make(map[string]VRAMAt, len(contextSizes)) + for _, ctxLen := range contextSizes { + vramBytes := computeVRAM(p, ctxLen, opts) + m[fmt.Sprint(ctxLen)] = VRAMAt{ + ContextLength: ctxLen, + VRAMBytes: vramBytes, + VRAMDisplay: FormatBytes(vramBytes), + } + } + return m +} + + +// EstimateMultiContext estimates model size and VRAM at multiple context sizes. +// It performs I/O once (resolveProfile) then computes VRAM for each context size. +func EstimateMultiContext(ctx context.Context, files []FileInput, contextSizes []uint32, + opts EstimateOptions, sizeResolver SizeResolver, ggufReader GGUFMetadataReader) (MultiContextEstimate, error) { + + if len(contextSizes) == 0 { + contextSizes = []uint32{8192} + } + + p := resolveProfile(ctx, files, sizeResolver, ggufReader) + + result := MultiContextEstimate{ + SizeBytes: p.sizeBytes, + SizeDisplay: FormatBytes(p.sizeBytes), + Estimates: buildEstimates(p, contextSizes, opts), + } + + if p.meta != nil && p.meta.MaximumContextLength > 0 { + result.ModelMaxContext = p.meta.MaximumContextLength + } + + return result, nil +} + // ParseSizeString parses a human-readable size string (e.g. "500MB", "14.5 GB", "2tb") // into bytes. Supports B, KB, MB, GB, TB, PB (case-insensitive, space optional). // Uses SI units (1 KB = 1000 B). @@ -136,7 +179,6 @@ func ParseSizeString(s string) (uint64, error) { s = strings.ToUpper(s) - // Find where the numeric part ends i := 0 for i < len(s) && (s[i] == '.' || (s[i] >= '0' && s[i] <= '9')) { i++ @@ -177,17 +219,6 @@ func ParseSizeString(s string) (uint64, error) { return uint64(num * float64(multiplier)), nil } -// EstimateFromSize builds an EstimateResult from a raw byte count. 
-func EstimateFromSize(sizeBytes uint64) EstimateResult { - vramBytes := sizeOnlyVRAM(sizeBytes, 8192) - return EstimateResult{ - SizeBytes: sizeBytes, - SizeDisplay: FormatBytes(sizeBytes), - VRAMBytes: vramBytes, - VRAMDisplay: FormatBytes(vramBytes), - } -} - func FormatBytes(n uint64) string { const unit = 1000 if n < unit { @@ -216,24 +247,29 @@ func DefaultGGUFReader() GGUFMetadataReader { } // ModelEstimateInput describes the inputs for a unified VRAM/size estimation. -// The estimator cascades through available data: files → size string → HF repo → zero. +// The estimator cascades through available data: files -> size string -> HF repo -> zero. type ModelEstimateInput struct { Files []FileInput // weight files with optional pre-known sizes Size string // gallery hardcoded size (e.g. "14.5GB") HFRepo string // HF repo ID or URL - Options EstimateOptions // context length, GPU layers, KV quant bits + Options EstimateOptions // GPU layers, KV quant bits } -// EstimateModel provides a unified VRAM estimation entry point. +// EstimateModelMultiContext provides a unified VRAM estimation entry point +// that returns estimates at multiple context sizes. // It tries (in order): // 1. Direct file-based estimation (GGUF metadata or file size heuristic) // 2. ParseSizeString from Size field -// 3. EstimateFromHFRepo +// 3. HuggingFace repo file listing // 4. Zero result -func EstimateModel(ctx context.Context, input ModelEstimateInput) (EstimateResult, error) { +func EstimateModelMultiContext(ctx context.Context, input ModelEstimateInput, contextSizes []uint32) (MultiContextEstimate, error) { + if len(contextSizes) == 0 { + contextSizes = []uint32{8192} + } + // 1. 
Try direct file estimation if len(input.Files) > 0 { - result, err := Estimate(ctx, input.Files, input.Options, DefaultCachedSizeResolver(), DefaultCachedGGUFReader()) + result, err := EstimateMultiContext(ctx, input.Files, contextSizes, input.Options, DefaultCachedSizeResolver(), DefaultCachedGGUFReader()) if err != nil { xlog.Debug("VRAM estimation from files failed", "error", err) } @@ -247,7 +283,11 @@ func EstimateModel(ctx context.Context, input ModelEstimateInput) (EstimateResul if sizeBytes, err := ParseSizeString(input.Size); err != nil { xlog.Debug("VRAM estimation from size string failed", "error", err, "size", input.Size) } else if sizeBytes > 0 { - return EstimateFromSize(sizeBytes), nil + return MultiContextEstimate{ + SizeBytes: sizeBytes, + SizeDisplay: FormatBytes(sizeBytes), + Estimates: buildEstimates(modelProfile{sizeBytes: sizeBytes}, contextSizes, EstimateOptions{}), + }, nil } } @@ -257,15 +297,19 @@ func EstimateModel(ctx context.Context, input ModelEstimateInput) (EstimateResul hfRepo = repoID } if hfRepo != "" { - result, err := EstimateFromHFRepo(ctx, hfRepo) + totalBytes, err := hfRepoWeightSize(ctx, hfRepo) if err != nil { xlog.Debug("VRAM estimation from HF repo failed", "error", err, "repo", hfRepo) } - if err == nil && result.SizeBytes > 0 { - return result, nil + if err == nil && totalBytes > 0 { + return MultiContextEstimate{ + SizeBytes: totalBytes, + SizeDisplay: FormatBytes(totalBytes), + Estimates: buildEstimates(modelProfile{sizeBytes: totalBytes}, contextSizes, EstimateOptions{}), + }, nil } } // 4. 
No estimation possible - return EstimateResult{}, nil + return MultiContextEstimate{}, nil } diff --git a/pkg/vram/estimate_test.go b/pkg/vram/estimate_test.go index 2036c8dad460..4431f6fe92f8 100644 --- a/pkg/vram/estimate_test.go +++ b/pkg/vram/estimate_test.go @@ -23,26 +23,25 @@ func (f fakeGGUFReader) ReadMetadata(ctx context.Context, uri string) (*GGUFMeta return f[uri], nil } -var _ = Describe("Estimate", func() { +var _ = Describe("EstimateMultiContext", func() { ctx := context.Background() + defaultCtx := []uint32{8192} Describe("empty or non-GGUF inputs", func() { It("returns zero size and vram for nil files", func() { - opts := EstimateOptions{ContextLength: 8192} - res, err := Estimate(ctx, nil, opts, nil, nil) + res, err := EstimateMultiContext(ctx, nil, defaultCtx, EstimateOptions{}, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(res.SizeBytes).To(Equal(uint64(0))) - Expect(res.VRAMBytes).To(Equal(uint64(0))) + Expect(res.Estimates["8192"].VRAMBytes).To(Equal(uint64(0))) Expect(res.SizeDisplay).To(Equal("0 B")) }) - It("counts only .gguf files and ignores other extensions", func() { + It("counts only weight files and ignores other extensions", func() { files := []FileInput{ {URI: "http://a/model.gguf", Size: 1_000_000_000}, {URI: "http://a/readme.txt", Size: 100}, } - opts := EstimateOptions{ContextLength: 8192} - res, err := Estimate(ctx, files, opts, nil, nil) + res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(res.SizeBytes).To(Equal(uint64(1_000_000_000))) }) @@ -52,8 +51,7 @@ var _ = Describe("Estimate", func() { {URI: "http://hf.co/model/model.safetensors", Size: 2_000_000_000}, {URI: "http://hf.co/model/model2.safetensors", Size: 3_000_000_000}, } - opts := EstimateOptions{ContextLength: 8192} - res, err := Estimate(ctx, files, opts, nil, nil) + res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, nil) Expect(err).ToNot(HaveOccurred()) 
Expect(res.SizeBytes).To(Equal(uint64(5_000_000_000))) }) @@ -62,24 +60,22 @@ var _ = Describe("Estimate", func() { Describe("GGUF size and resolver", func() { It("uses size resolver when file size is not set", func() { sizes := fakeSizeResolver{"http://example.com/model.gguf": 1_500_000_000} - opts := EstimateOptions{ContextLength: 8192} files := []FileInput{{URI: "http://example.com/model.gguf"}} - res, err := Estimate(ctx, files, opts, sizes, nil) + res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, sizes, nil) Expect(err).ToNot(HaveOccurred()) Expect(res.SizeBytes).To(Equal(uint64(1_500_000_000))) - Expect(res.VRAMBytes).To(BeNumerically(">=", res.SizeBytes)) + Expect(res.Estimates["8192"].VRAMBytes).To(BeNumerically(">=", res.SizeBytes)) Expect(res.SizeDisplay).To(Equal("1.5 GB")) }) It("uses size-only VRAM formula when metadata is missing and size is large", func() { sizes := fakeSizeResolver{"http://a/model.gguf": 10_000_000_000} - opts := EstimateOptions{ContextLength: 8192} files := []FileInput{{URI: "http://a/model.gguf"}} - res, err := Estimate(ctx, files, opts, sizes, nil) + res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, sizes, nil) Expect(err).ToNot(HaveOccurred()) - Expect(res.VRAMBytes).To(BeNumerically(">", 10_000_000_000)) + Expect(res.Estimates["8192"].VRAMBytes).To(BeNumerically(">", 10_000_000_000)) }) It("sums size for multiple GGUF shards", func() { @@ -87,18 +83,16 @@ var _ = Describe("Estimate", func() { {URI: "http://a/shard1.gguf", Size: 10_000_000_000}, {URI: "http://a/shard2.gguf", Size: 5_000_000_000}, } - opts := EstimateOptions{ContextLength: 8192} - res, err := Estimate(ctx, files, opts, nil, nil) + res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(res.SizeBytes).To(Equal(uint64(15_000_000_000))) }) It("formats size display correctly", func() { files := []FileInput{{URI: "http://a/model.gguf", Size: 
2_500_000_000}} - opts := EstimateOptions{ContextLength: 8192} - res, err := Estimate(ctx, files, opts, nil, nil) + res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(res.SizeDisplay).To(Equal("2.5 GB")) }) @@ -108,24 +102,94 @@ var _ = Describe("Estimate", func() { It("uses metadata for VRAM when reader returns meta and partial offload", func() { meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096} reader := fakeGGUFReader{"http://a/model.gguf": meta} - opts := EstimateOptions{ContextLength: 8192, GPULayers: 20} + opts := EstimateOptions{GPULayers: 20} files := []FileInput{{URI: "http://a/model.gguf", Size: 8_000_000_000}} - res, err := Estimate(ctx, files, opts, nil, reader) + res, err := EstimateMultiContext(ctx, files, defaultCtx, opts, nil, reader) Expect(err).ToNot(HaveOccurred()) - Expect(res.VRAMBytes).To(BeNumerically(">", 0)) + Expect(res.Estimates["8192"].VRAMBytes).To(BeNumerically(">", 0)) }) It("uses metadata head counts for KV and yields vram > size", func() { files := []FileInput{{URI: "http://a/model.gguf", Size: 15_000_000_000}} meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096, HeadCount: 32, HeadCountKV: 8} reader := fakeGGUFReader{"http://a/model.gguf": meta} - opts := EstimateOptions{ContextLength: 8192} - res, err := Estimate(ctx, files, opts, nil, reader) + res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, reader) Expect(err).ToNot(HaveOccurred()) Expect(res.SizeBytes).To(Equal(uint64(15_000_000_000))) - Expect(res.VRAMBytes).To(BeNumerically(">", res.SizeBytes)) + Expect(res.Estimates["8192"].VRAMBytes).To(BeNumerically(">", res.SizeBytes)) + }) + + It("populates ModelMaxContext from GGUF metadata", func() { + meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096, MaximumContextLength: 131072} + reader := fakeGGUFReader{"http://a/model.gguf": meta} + files := []FileInput{{URI: "http://a/model.gguf", Size: 8_000_000_000}} + 
+ res, err := EstimateMultiContext(ctx, files, defaultCtx, EstimateOptions{}, nil, reader) + Expect(err).ToNot(HaveOccurred()) + Expect(res.ModelMaxContext).To(Equal(uint64(131072))) + }) + }) + + Describe("multi-context behavior", func() { + It("returns estimates for all requested context sizes", func() { + files := []FileInput{{URI: "http://a/model.gguf", Size: 4_000_000_000}} + sizes := []uint32{8192, 32768, 131072} + + res, err := EstimateMultiContext(ctx, files, sizes, EstimateOptions{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(res.Estimates).To(HaveLen(3)) + Expect(res.Estimates).To(HaveKey("8192")) + Expect(res.Estimates).To(HaveKey("32768")) + Expect(res.Estimates).To(HaveKey("131072")) + }) + + It("VRAM increases monotonically with context size", func() { + files := []FileInput{{URI: "http://a/model.gguf", Size: 4_000_000_000}} + meta := &GGUFMeta{BlockCount: 32, EmbeddingLength: 4096, HeadCount: 32, HeadCountKV: 8} + reader := fakeGGUFReader{"http://a/model.gguf": meta} + sizes := []uint32{8192, 16384, 32768, 65536, 131072, 262144} + + res, err := EstimateMultiContext(ctx, files, sizes, EstimateOptions{}, nil, reader) + Expect(err).ToNot(HaveOccurred()) + + prev := uint64(0) + for _, sz := range sizes { + v := res.VRAMForContext(sz) + Expect(v).To(BeNumerically(">", prev), "VRAM should increase at context %d", sz) + prev = v + } + }) + + It("size is constant across context sizes", func() { + files := []FileInput{{URI: "http://a/model.gguf", Size: 4_000_000_000}} + sizes := []uint32{8192, 32768} + + res, err := EstimateMultiContext(ctx, files, sizes, EstimateOptions{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + Expect(res.SizeBytes).To(Equal(uint64(4_000_000_000))) + }) + + It("defaults to [8192] when contextSizes is empty", func() { + files := []FileInput{{URI: "http://a/model.gguf", Size: 4_000_000_000}} + + res, err := EstimateMultiContext(ctx, files, nil, EstimateOptions{}, nil, nil) + Expect(err).ToNot(HaveOccurred()) + 
Expect(res.Estimates).To(HaveLen(1)) + Expect(res.Estimates).To(HaveKey("8192")) + }) + }) + + Describe("VRAMForContext helper", func() { + It("returns 0 for missing context size", func() { + res := MultiContextEstimate{ + Estimates: map[string]VRAMAt{ + "8192": {VRAMBytes: 5000}, + }, + } + Expect(res.VRAMForContext(99999)).To(Equal(uint64(0))) + Expect(res.VRAMForContext(8192)).To(Equal(uint64(5000))) }) }) }) diff --git a/pkg/vram/hf_estimate.go b/pkg/vram/hf_estimate.go index 2d1ca9b42d16..9f9ff4e6a773 100644 --- a/pkg/vram/hf_estimate.go +++ b/pkg/vram/hf_estimate.go @@ -4,7 +4,6 @@ import ( "context" "strings" "sync" - "time" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" ) @@ -15,13 +14,11 @@ var ( ) type hfSizeCacheEntry struct { - result EstimateResult - err error - expiresAt time.Time + totalBytes uint64 + err error + generation uint64 } -const hfSizeCacheTTL = 15 * time.Minute - // ExtractHFRepoID extracts a HuggingFace repo ID from a string. // It handles both short form ("org/model") and full URL form // ("https://huggingface.co/org/model", "huggingface.co/org/model"). @@ -62,30 +59,31 @@ func ExtractHFRepoID(s string) (string, bool) { return "", false } -// EstimateFromHFRepo estimates model size by querying the HuggingFace API for file listings. -// Results are cached for 15 minutes. -func EstimateFromHFRepo(ctx context.Context, repoID string) (EstimateResult, error) { +// hfRepoWeightSize returns the total weight file size for a HuggingFace repo. +// Results are cached and invalidated when the gallery generation changes. 
+func hfRepoWeightSize(ctx context.Context, repoID string) (uint64, error) { + gen := currentGeneration() hfSizeCacheMu.Lock() - if entry, ok := hfSizeCacheData[repoID]; ok && time.Now().Before(entry.expiresAt) { + if entry, ok := hfSizeCacheData[repoID]; ok && entry.generation == gen { hfSizeCacheMu.Unlock() - return entry.result, entry.err + return entry.totalBytes, entry.err } hfSizeCacheMu.Unlock() - result, err := estimateFromHFRepoUncached(ctx, repoID) + totalBytes, err := hfRepoWeightSizeUncached(ctx, repoID) hfSizeCacheMu.Lock() hfSizeCacheData[repoID] = hfSizeCacheEntry{ - result: result, - err: err, - expiresAt: time.Now().Add(hfSizeCacheTTL), + totalBytes: totalBytes, + err: err, + generation: gen, } hfSizeCacheMu.Unlock() - return result, err + return totalBytes, err } -func estimateFromHFRepoUncached(ctx context.Context, repoID string) (EstimateResult, error) { +func hfRepoWeightSizeUncached(ctx context.Context, repoID string) (uint64, error) { client := hfapi.NewClient() type listResult struct { @@ -100,17 +98,17 @@ func estimateFromHFRepoUncached(ctx context.Context, repoID string) (EstimateRes select { case <-ctx.Done(): - return EstimateResult{}, ctx.Err() + return 0, ctx.Err() case res := <-ch: if res.err != nil { - return EstimateResult{}, res.err + return 0, res.err } - return estimateFromFileInfos(res.files), nil + return sumWeightFileBytes(res.files), nil } } -func estimateFromFileInfos(files []hfapi.FileInfo) EstimateResult { - var totalSize int64 +func sumWeightFileBytes(files []hfapi.FileInfo) uint64 { + var total int64 for _, f := range files { if f.Type != "file" { continue @@ -128,20 +126,10 @@ func estimateFromFileInfos(files []hfapi.FileInfo) EstimateResult { if f.LFS != nil && f.LFS.Size > 0 { size = f.LFS.Size } - totalSize += size - } - - if totalSize <= 0 { - return EstimateResult{} + total += size } - - sizeBytes := uint64(totalSize) - vramBytes := sizeOnlyVRAM(sizeBytes, 8192) - - return EstimateResult{ - SizeBytes: sizeBytes, - 
SizeDisplay: FormatBytes(sizeBytes), - VRAMBytes: vramBytes, - VRAMDisplay: FormatBytes(vramBytes), + if total < 0 { + return 0 } + return uint64(total) } diff --git a/pkg/vram/types.go b/pkg/vram/types.go index 476c50404122..4893342bef74 100644 --- a/pkg/vram/types.go +++ b/pkg/vram/types.go @@ -1,6 +1,9 @@ package vram -import "context" +import ( + "context" + "fmt" +) // FileInput represents a single model file for estimation (URI and optional pre-known size). type FileInput struct { @@ -28,16 +31,33 @@ type GGUFMetadataReader interface { } // EstimateOptions configures VRAM/size estimation. +// GPULayers and KVQuantBits apply uniformly across all context sizes. type EstimateOptions struct { - ContextLength uint32 - GPULayers int - KVQuantBits int + GPULayers int + KVQuantBits int } -// EstimateResult holds estimated download size and VRAM with display strings. -type EstimateResult struct { - SizeBytes uint64 `json:"sizeBytes"` // total model weight size in bytes - SizeDisplay string `json:"sizeDisplay"` // human-readable size (e.g. "4.2 GB") - VRAMBytes uint64 `json:"vramBytes"` // estimated VRAM usage in bytes - VRAMDisplay string `json:"vramDisplay"` // human-readable VRAM (e.g. "6.1 GB") +// VRAMAt holds the VRAM estimate at a specific context size. +type VRAMAt struct { + ContextLength uint32 `json:"contextLength"` + VRAMBytes uint64 `json:"vramBytes"` + VRAMDisplay string `json:"vramDisplay"` +} + +// MultiContextEstimate holds VRAM estimates for one or more context sizes, +// computed from a single metadata fetch. +type MultiContextEstimate struct { + SizeBytes uint64 `json:"sizeBytes"` + SizeDisplay string `json:"sizeDisplay"` + Estimates map[string]VRAMAt `json:"estimates"` // keys: context size as string + ModelMaxContext uint64 `json:"modelMaxContext,omitempty"` // from GGUF metadata +} + +// VRAMForContext is a convenience method that returns the VRAMBytes for a +// specific context size, or 0 if not present. 
+func (m MultiContextEstimate) VRAMForContext(ctxLen uint32) uint64 {
+	if e, ok := m.Estimates[fmt.Sprint(ctxLen)]; ok {
+		return e.VRAMBytes
+	}
+	return 0
+}