Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/runtime/vm/attn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ struct Sequence {
* in the KV cache even when sliding window is enabled.
*/
int last_block_attn_sink_size = 0;
/*! \brief The LoRA adapter id associated with the sequence. */
int32_t lora_adapter_id = 0;

/*! \brief Whether the current appended tokens form a chain (not a tree). */
bool is_chain = true;
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/vm/kv_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
&AttentionKVCacheObj::EnableSlidingWindowForSeq)
.def_method("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes",
&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes)
.def_method("vm.builtin.attention_kv_cache_set_sequence_lora_adapter",
&AttentionKVCacheObj::SetSequenceLoraAdapter)
.def_method("vm.builtin.attention_kv_cache_get_sequence_lora_adapter",
&AttentionKVCacheObj::GetSequenceLoraAdapter)
.def_method("vm.builtin.attention_kv_cache_get_current_lora_adapter_ids",
&AttentionKVCacheObj::GetCurrentLoraAdapterIds)
.def_method("vm.builtin.attention_kv_cache_empty", &AttentionKVCacheObj::Empty)
.def_method("vm.builtin.attention_kv_cache_get_num_available_pages",
&AttentionKVCacheObj::GetNumAvailablePages)
Expand Down
20 changes: 20 additions & 0 deletions src/runtime/vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,26 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
const IntTuple& leaf_indices) = 0;

/*!
 * \brief Set the LoRA adapter id for the specified sequence.
 * \note Newly added sequences start with the default adapter id 0; forked
 * sequences inherit the adapter id of their parent at fork time.
 * \param seq_id The id of the sequence to update.
 * \param lora_adapter_id The LoRA adapter id to associate with the sequence.
 */
virtual void SetSequenceLoraAdapter(int64_t seq_id, int64_t lora_adapter_id) = 0;

/*!
 * \brief Get the LoRA adapter id for the specified sequence.
 * \param seq_id The id of the sequence to query.
 * \return The LoRA adapter id associated with the sequence.
 */
virtual int64_t GetSequenceLoraAdapter(int64_t seq_id) = 0;

/*!
 * \brief Get the LoRA adapter ids of the current batch specified in BeginForward.
 * \return The LoRA adapter ids, in the same order as the current batch sequence ids.
 */
virtual ffi::Shape GetCurrentLoraAdapterIds() = 0;

/*! \brief Prepare for the disaggregation KV data receive for the specified sequence and length.*/
virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0;

Expand Down
31 changes: 30 additions & 1 deletion src/runtime/vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
std::vector<HostMemoryVector> k_rope_pos_offset_on_depths_host_;
std::vector<HostMemoryVector> k_rope_pos_offset_sliding_window_on_depths_host_;
HostMemoryVector k_ragged_rope_pos_offset_host_;
HostMemoryVector current_lora_adapter_ids_host_;
HostMemoryVector q_rope_position_map_host_;
HostMemoryVector append_position_map_host_;
HostMemoryVector cur_append_lengths_indptr_host_;
Expand Down Expand Up @@ -414,6 +415,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
k_ragged_rope_pos_offset_host_ =
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device);
current_lora_adapter_ids_host_ =
HostMemoryVector(reserved_num_seqs, dtype_aux_, preferred_host_device);
q_rope_position_map_host_ =
HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
append_position_map_host_ =
Expand Down Expand Up @@ -685,7 +688,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
break;
}
// Create the child sequence with the child block.
seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
auto [child_it, inserted] =
seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
TVM_FFI_ICHECK(inserted);
child_it->second.lora_adapter_id = parent_it->second.lora_adapter_id;
dirty_aux_data_device_ = true;
}

Expand Down Expand Up @@ -874,11 +880,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
sequences.reserve(cur_batch_size_);
last_block_length_before_append.reserve(cur_batch_size_);
k_ragged_rope_pos_offset_host_.clear();
current_lora_adapter_ids_host_.clear();
for (int i = 0; i < cur_batch_size_; ++i) {
auto it = seq_map_.find(seq_ids[i]);
TVM_FFI_ICHECK(it != seq_map_.end())
<< "The sequence \"" << seq_ids[i] << "\" cannot be found in KV cache.";
sequences.push_back(&it->second);
current_lora_adapter_ids_host_.push_back(it->second.lora_adapter_id);
last_block_length_before_append.push_back(
global_block_pool_[it->second.last_block_idx].seq_length);
int k_rope_offset = it->second.seq_length;
Expand Down Expand Up @@ -1195,6 +1203,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

/*!
 * \brief Associate a LoRA adapter id with an existing sequence.
 * \param seq_id The id of the sequence to update; must be present in the cache.
 * \param lora_adapter_id The adapter id to record; must fit in a non-negative int32.
 */
void SetSequenceLoraAdapter(int64_t seq_id, int64_t lora_adapter_id) final {
  // The per-sequence field is int32, so reject anything outside its range up front.
  constexpr int64_t kMaxAdapterId = std::numeric_limits<int32_t>::max();
  TVM_FFI_ICHECK(lora_adapter_id >= 0) << "LoRA adapter id must be non-negative.";
  TVM_FFI_ICHECK(lora_adapter_id <= kMaxAdapterId) << "LoRA adapter id exceeds int32 range.";
  auto seq_it = seq_map_.find(seq_id);
  TVM_FFI_ICHECK(seq_it != seq_map_.end())
      << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
  seq_it->second.lora_adapter_id = static_cast<int32_t>(lora_adapter_id);
}

/*!
 * \brief Look up the LoRA adapter id recorded for a sequence.
 * \param seq_id The id of the sequence to query; must be present in the cache.
 * \return The adapter id currently associated with the sequence.
 */
int64_t GetSequenceLoraAdapter(int64_t seq_id) final {
  const auto seq_it = seq_map_.find(seq_id);
  TVM_FFI_ICHECK(seq_it != seq_map_.end())
      << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
  return seq_it->second.lora_adapter_id;
}

/*!
 * \brief Get the LoRA adapter ids of the current batch.
 * \return One adapter id per sequence, in the order the sequence ids were
 * passed to BeginForward (the host vector is rebuilt there each step).
 */
ffi::Shape GetCurrentLoraAdapterIds() final {
return current_lora_adapter_ids_host_.as_int_tuple();
}

ffi::Shape DisaggPrepareRecv(int64_t seq_id, int append_length) final {
// No CPU to GPU copy is needed.
// Essentially we
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@
fbegin_forward = None
fend_forward = None
fcommit_accepted_token_tree_nodes = None
fset_sequence_lora_adapter = None
fget_sequence_lora_adapter = None
fget_current_lora_adapter_ids = None
fattention_with_fuse_qkv = None
fis_empty = None
fdebug_get_kv = None
Expand All @@ -88,6 +91,7 @@
def set_global_func(head_dim, dtype):
global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes
global fset_sequence_lora_adapter, fget_sequence_lora_adapter, fget_current_lora_adapter_ids
global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode
global \
Expand All @@ -110,6 +114,15 @@ def set_global_func(head_dim, dtype):
fcommit_accepted_token_tree_nodes = tvm.get_global_func(
"vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes"
)
fset_sequence_lora_adapter = tvm.get_global_func(
"vm.builtin.attention_kv_cache_set_sequence_lora_adapter"
)
fget_sequence_lora_adapter = tvm.get_global_func(
"vm.builtin.attention_kv_cache_get_sequence_lora_adapter"
)
fget_current_lora_adapter_ids = tvm.get_global_func(
"vm.builtin.attention_kv_cache_get_current_lora_adapter_ids"
)
fattention_with_fuse_qkv = tvm.get_global_func(
"vm.builtin.attention_kv_cache_attention_with_fused_qkv"
)
Expand Down Expand Up @@ -254,6 +267,27 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v):
tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3)


def test_lora_adapter_metadata(kv_cache_and_config):
    """Exercise per-sequence LoRA adapter id set/get, fork inheritance, and batch query."""
    kv_cache = kv_cache_and_config[0]

    # Freshly added sequences carry the default adapter id 0.
    fadd_sequence(kv_cache, 1)
    fadd_sequence(kv_cache, 2)
    for seq_id in (1, 2):
        assert fget_sequence_lora_adapter(kv_cache, seq_id) == 0

    # Assign distinct adapter ids and read them back.
    adapter_of = {1: 7, 2: 11}
    for seq_id, adapter_id in adapter_of.items():
        fset_sequence_lora_adapter(kv_cache, seq_id, adapter_id)
    for seq_id, adapter_id in adapter_of.items():
        assert fget_sequence_lora_adapter(kv_cache, seq_id) == adapter_id

    # A forked sequence inherits its parent's adapter id.
    ffork_sequence(kv_cache, 1, 3, -1)
    assert fget_sequence_lora_adapter(kv_cache, 3) == 7

    # The batch-level query returns ids in the same order as the begin_forward seq ids.
    fbegin_forward(kv_cache, ShapeTuple([2, 3, 1]), ShapeTuple([1, 2, 1]))
    assert list(fget_current_lora_adapter_ids(kv_cache)) == [11, 7, 7]
    fend_forward(kv_cache)


def f_apply_rotary(x, offset, scale, theta, offset_list: list[int] | None = None):
# x: (N, H, D)
assert len(x.shape) == 3
Expand Down