diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4a0b55d3..d2e4937c9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+- fix: Handle embedding models without KV memory and test embeddings with a real GGUF embedding model by @abetlen in #2160
 - fix(ci): Shrink CUDA wheel fatbins so CUDA releases stay under GitHub's asset size limit by @abetlen in #2158
 
 ## [0.3.18]
diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 6862135aa..9e9bcd407 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -288,7 +288,9 @@ def pooling_type(self) -> int:
         return llama_cpp.llama_pooling_type(self.ctx)
 
     def kv_cache_clear(self):
-        assert self.memory is not None, "Memory is not initialized"
+        # Embedding models with non-causal attention may not allocate memory.
+        if self.memory is None:
+            return
         llama_cpp.llama_memory_clear(self.memory, True)
 
     def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int) -> bool:
diff --git a/tests/test_llama.py b/tests/test_llama.py
index 1a70c74d4..23928fff6 100644
--- a/tests/test_llama.py
+++ b/tests/test_llama.py
@@ -64,6 +64,14 @@ def llama_cpp_model_path():
     return model_path
 
 
+@pytest.fixture
+def llama_cpp_embedding_model_path():
+    repo_id = "CompendiumLabs/bge-small-en-v1.5-gguf"
+    filename = "bge-small-en-v1.5-q4_k_m.gguf"
+    model_path = hf_hub_download(repo_id, filename)
+    return model_path
+
+
 def test_real_model(llama_cpp_model_path):
     import os
 
@@ -225,9 +233,9 @@ def logit_processor_func(input_ids, logits):
     assert number_1 == number_3
 
 
-def test_real_llama_embeddings(llama_cpp_model_path):
+def test_real_llama_embeddings(llama_cpp_embedding_model_path):
     model = llama_cpp.Llama(
-        llama_cpp_model_path,
+        llama_cpp_embedding_model_path,
         n_ctx=32,
         n_batch=32,
         n_ubatch=32,
@@ -237,5 +245,5 @@ def test_real_llama_embeddings(llama_cpp_model_path):
         flash_attn=True,
         embedding=True,
     )
-    # Smoke test for now
-    model.embed("Hello World")
+    embedding = model.embed("Hello World")
+    assert len(embedding) > 0