From 5cfa76597430f4469a0889dc74c0152e50e5b3b8 Mon Sep 17 00:00:00 2001 From: yash solanki Date: Sat, 7 Mar 2026 12:56:41 +0530 Subject: [PATCH 1/3] runtime: add LoRA adapter metadata to paged kv cache --- src/runtime/vm/attn_utils.h | 2 ++ src/runtime/vm/kv_state.cc | 6 ++++ src/runtime/vm/kv_state.h | 20 +++++++++++ src/runtime/vm/paged_kv_cache.cc | 34 ++++++++++++++++++- ...me_builtin_paged_attention_kv_cache_cpu.py | 34 +++++++++++++++++++ 5 files changed, 95 insertions(+), 1 deletion(-) diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index afb962e4fc6f..e9de65215c36 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -180,6 +180,8 @@ struct Sequence { * in the KV cache even when sliding window is enabled. */ int last_block_attn_sink_size = 0; + /*! \brief The LoRA adapter id associated with the sequence. */ + int32_t lora_adapter_id = 0; /*! \brief Whether the current appended tokens form a chain (not a tree). */ bool is_chain = true; diff --git a/src/runtime/vm/kv_state.cc b/src/runtime/vm/kv_state.cc index b82d934f5c67..906913c10b92 100644 --- a/src/runtime/vm/kv_state.cc +++ b/src/runtime/vm/kv_state.cc @@ -65,6 +65,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { &AttentionKVCacheObj::EnableSlidingWindowForSeq) .def_method("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes", &AttentionKVCacheObj::CommitAcceptedTokenTreeNodes) + .def_method("vm.builtin.attention_kv_cache_set_sequence_lora_adapter", + &AttentionKVCacheObj::SetSequenceLoraAdapter) + .def_method("vm.builtin.attention_kv_cache_get_sequence_lora_adapter", + &AttentionKVCacheObj::GetSequenceLoraAdapter) + .def_method("vm.builtin.attention_kv_cache_get_current_lora_adapter_ids", + &AttentionKVCacheObj::GetCurrentLoraAdapterIds) .def_method("vm.builtin.attention_kv_cache_empty", &AttentionKVCacheObj::Empty) .def_method("vm.builtin.attention_kv_cache_get_num_available_pages", &AttentionKVCacheObj::GetNumAvailablePages) diff --git 
a/src/runtime/vm/kv_state.h b/src/runtime/vm/kv_state.h index 33c669f18ab2..9573e6bb70f1 100644 --- a/src/runtime/vm/kv_state.h +++ b/src/runtime/vm/kv_state.h @@ -158,6 +158,26 @@ class AttentionKVCacheObj : public KVStateObj { virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple& leaf_indices) = 0; + /*! + * \brief Set the LoRA adapter id for the specified sequence. + * \param seq_id The id of the sequence to update. + * \param lora_adapter_id The LoRA adapter id to associate with the sequence. + */ + virtual void SetSequenceLoraAdapter(int64_t seq_id, int64_t lora_adapter_id) = 0; + + /*! + * \brief Get the LoRA adapter id for the specified sequence. + * \param seq_id The id of the sequence to query. + * \return The LoRA adapter id associated with the sequence. + */ + virtual int64_t GetSequenceLoraAdapter(int64_t seq_id) = 0; + + /*! + * \brief Get the LoRA adapter ids of the current batch specified in BeginForward. + * \return The LoRA adapter ids, in the same order as the current batch sequence ids. + */ + virtual ffi::Shape GetCurrentLoraAdapterIds() = 0; + /*! 
\brief Prepare for the disaggregation KV data receive for the specified sequence and length.*/ virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0; diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 36f7697237e2..958047226f13 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -206,6 +206,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::vector k_rope_pos_offset_on_depths_host_; std::vector k_rope_pos_offset_sliding_window_on_depths_host_; HostMemoryVector k_ragged_rope_pos_offset_host_; + HostMemoryVector current_lora_adapter_ids_host_; HostMemoryVector q_rope_position_map_host_; HostMemoryVector append_position_map_host_; HostMemoryVector cur_append_lengths_indptr_host_; @@ -414,6 +415,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } k_ragged_rope_pos_offset_host_ = HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device); + current_lora_adapter_ids_host_ = + HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device); q_rope_position_map_host_ = HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device); append_position_map_host_ = @@ -685,7 +688,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { break; } // Create the child sequence with the child block. 
- seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)}); + auto [child_it, inserted] = + seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)}); + ICHECK(inserted); + child_it->second.lora_adapter_id = parent_it->second.lora_adapter_id; dirty_aux_data_device_ = true; } @@ -874,11 +880,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { sequences.reserve(cur_batch_size_); last_block_length_before_append.reserve(cur_batch_size_); k_ragged_rope_pos_offset_host_.clear(); + current_lora_adapter_ids_host_.clear(); for (int i = 0; i < cur_batch_size_; ++i) { auto it = seq_map_.find(seq_ids[i]); TVM_FFI_ICHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache."; sequences.push_back(&it->second); + current_lora_adapter_ids_host_.push_back(it->second.lora_adapter_id); last_block_length_before_append.push_back( global_block_pool_[it->second.last_block_idx].seq_length); int k_rope_offset = it->second.seq_length; @@ -1195,6 +1203,30 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } + void SetSequenceLoraAdapter(int64_t seq_id, int64_t lora_adapter_id) final { + CHECK_GE(lora_adapter_id, 0) << "LoRA adapter id must be non-negative."; + CHECK_LE(lora_adapter_id, std::numeric_limits<int32_t>::max()) + << "LoRA adapter id exceeds int32 range."; + auto it = seq_map_.find(seq_id); + CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; + it->second.lora_adapter_id = static_cast<int32_t>(lora_adapter_id); + } + + int64_t GetSequenceLoraAdapter(int64_t seq_id) final { + auto it = seq_map_.find(seq_id); + CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; + return it->second.lora_adapter_id; + } + + ffi::Shape GetCurrentLoraAdapterIds() final { + std::vector<int64_t> adapter_ids; + adapter_ids.reserve(current_lora_adapter_ids_host_.size()); + for (size_t i = 0; i < 
current_lora_adapter_ids_host_.size(); ++i) { + adapter_ids.push_back(current_lora_adapter_ids_host_[i]); + } + return ffi::Shape(std::move(adapter_ids)); + } + ffi::Shape DisaggPrepareRecv(int64_t seq_id, int append_length) final { // No CPU to GPU copy is needed. // Essentially we diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py index 8f75804cbba6..a6609438ea50 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py @@ -65,6 +65,9 @@ fbegin_forward = None fend_forward = None fcommit_accepted_token_tree_nodes = None +fset_sequence_lora_adapter = None +fget_sequence_lora_adapter = None +fget_current_lora_adapter_ids = None fattention_with_fuse_qkv = None fis_empty = None fdebug_get_kv = None @@ -88,6 +91,7 @@ def set_global_func(head_dim, dtype): global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes + global fset_sequence_lora_adapter, fget_sequence_lora_adapter, fget_current_lora_adapter_ids global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode global \ @@ -110,6 +114,15 @@ def set_global_func(head_dim, dtype): fcommit_accepted_token_tree_nodes = tvm.get_global_func( "vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes" ) + fset_sequence_lora_adapter = tvm.get_global_func( + "vm.builtin.attention_kv_cache_set_sequence_lora_adapter" + ) + fget_sequence_lora_adapter = tvm.get_global_func( + "vm.builtin.attention_kv_cache_get_sequence_lora_adapter" + ) + fget_current_lora_adapter_ids = tvm.get_global_func( + "vm.builtin.attention_kv_cache_get_current_lora_adapter_ids" + ) fattention_with_fuse_qkv = tvm.get_global_func( 
"vm.builtin.attention_kv_cache_attention_with_fused_qkv" ) @@ -254,6 +267,27 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v): tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3) +def test_lora_adapter_metadata(kv_cache_and_config): + kv_cache, _, _ = kv_cache_and_config + + fadd_sequence(kv_cache, 1) + fadd_sequence(kv_cache, 2) + assert fget_sequence_lora_adapter(kv_cache, 1) == 0 + assert fget_sequence_lora_adapter(kv_cache, 2) == 0 + + fset_sequence_lora_adapter(kv_cache, 1, 7) + fset_sequence_lora_adapter(kv_cache, 2, 11) + assert fget_sequence_lora_adapter(kv_cache, 1) == 7 + assert fget_sequence_lora_adapter(kv_cache, 2) == 11 + + ffork_sequence(kv_cache, 1, 3, -1) + assert fget_sequence_lora_adapter(kv_cache, 3) == 7 + + fbegin_forward(kv_cache, ShapeTuple([2, 3, 1]), ShapeTuple([1, 2, 1])) + assert list(fget_current_lora_adapter_ids(kv_cache)) == [11, 7, 7] + fend_forward(kv_cache) + + def f_apply_rotary(x, offset, scale, theta, offset_list: list[int] | None = None): # x: (N, H, D) assert len(x.shape) == 3 From 9eb531b38b2ace1e04b26f8975b6bc6129e8f23f Mon Sep 17 00:00:00 2001 From: yash solanki Date: Sat, 7 Mar 2026 21:16:45 +0530 Subject: [PATCH 2/3] runtime: use tvm ffi checks in paged kv cache lora metadata --- src/runtime/vm/paged_kv_cache.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 958047226f13..1ed72fb60e6c 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -690,7 +690,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Create the child sequence with the child block. 
auto [child_it, inserted] = seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)}); - ICHECK(inserted); + TVM_FFI_ICHECK(inserted); child_it->second.lora_adapter_id = parent_it->second.lora_adapter_id; dirty_aux_data_device_ = true; } @@ -1204,17 +1204,19 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } void SetSequenceLoraAdapter(int64_t seq_id, int64_t lora_adapter_id) final { - CHECK_GE(lora_adapter_id, 0) << "LoRA adapter id must be non-negative."; - CHECK_LE(lora_adapter_id, std::numeric_limits<int32_t>::max()) + TVM_FFI_ICHECK(lora_adapter_id >= 0) << "LoRA adapter id must be non-negative."; + TVM_FFI_ICHECK(lora_adapter_id <= std::numeric_limits<int32_t>::max()) << "LoRA adapter id exceeds int32 range."; auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; it->second.lora_adapter_id = static_cast<int32_t>(lora_adapter_id); } int64_t GetSequenceLoraAdapter(int64_t seq_id) final { auto it = seq_map_.find(seq_id); - CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; + TVM_FFI_ICHECK(it != seq_map_.end()) + << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; return it->second.lora_adapter_id; } From 16eba55a1b3809442780484c30cbf85853506d0d Mon Sep 17 00:00:00 2001 From: yash solanki Date: Mon, 9 Mar 2026 12:07:26 +0530 Subject: [PATCH 3/3] runtime: reuse host memory vector shape conversion --- src/runtime/vm/paged_kv_cache.cc | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index 1ed72fb60e6c..ebe0dcdd5674 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -1221,12 +1221,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } ffi::Shape 
GetCurrentLoraAdapterIds() final { - std::vector<int64_t> adapter_ids; - adapter_ids.reserve(current_lora_adapter_ids_host_.size()); - for (size_t i = 0; i < current_lora_adapter_ids_host_.size(); ++i) { - adapter_ids.push_back(current_lora_adapter_ids_host_[i]); - } - return ffi::Shape(std::move(adapter_ids)); + return current_lora_adapter_ids_host_.as_int_tuple(); } ffi::Shape DisaggPrepareRecv(int64_t seq_id, int append_length) final {