Improve the IVF-PQ Coarse Batch Size Workspace Estimation #1937
julianmi wants to merge 4 commits into rapidsai:release/26.04
Conversation
794088a to
9a14aca
achirkin left a comment
Thanks @julianmi for the PR!
I think the current fix most likely avoids the OOM for all reasonable use cases. However, it doesn't address the root problem: IVF-PQ search should never fail with OOM; it computes its workspace requirements internally and should scale the batch size accordingly.
Looking at the IVF-PQ search code, I think we need to make the coarse batch size estimate more precise. Something like this should work (untested):
```cpp
inline auto get_max_coarse_batch_size(raft::resources const& res,
                                      const search_params& params,
                                      uint32_t n_probes,
                                      uint32_t n_lists,
                                      uint32_t n_queries,
                                      uint32_t dim_ext,
                                      uint32_t rot_dim) -> uint32_t
{
  size_t gemm_elem_size;
  size_t qc_elem_size;
  switch (params.coarse_search_dtype) {
    case CUDA_R_32F: gemm_elem_size = 4; qc_elem_size = 4; break;
    case CUDA_R_16F: gemm_elem_size = 2; qc_elem_size = 2; break;
    case CUDA_R_8I: gemm_elem_size = 1; qc_elem_size = 4; break;
    default: RAFT_FAIL("Unexpected coarse_search_dtype (%d)", int(params.coarse_search_dtype));
  }
  // Persistent allocations that live for the entire search call.
  auto persistent_per_query = static_cast<size_t>(dim_ext) * gemm_elem_size +
                              static_cast<size_t>(rot_dim) * sizeof(float) +
                              static_cast<size_t>(n_probes) * sizeof(uint32_t);
  // Transient allocations during coarse search (select_clusters): qc_distances + cluster_dists.
  auto transient_per_query = static_cast<size_t>(n_lists + n_probes) * qc_elem_size;
  auto total_per_query     = persistent_per_query + transient_per_query;
  auto max_per_ws = raft::resource::get_workspace_free_bytes(res) / total_per_query;
  return std::max<uint32_t>(
    1,
    std::min<uint32_t>(max_per_ws,
                       std::min<uint32_t>(params.max_internal_batch_size, n_queries)));
}
```
Thank you, this is a much better fix. As discussed offline, we need to keep the |
The OpenAI 5M dataset runs OOM when using quantized data (`float` to `int8`) as shown in this example. The `kMinWorkspaceRatio` is too small for the `int8` data, such that the IVF-PQ search `max_internal_batch_size` is not reduced. Doubling `kMinWorkspaceRatio` fixes this. There is no performance degradation for the OpenAI 5M `float` dataset with this change. Let me know if I should test more datasets. CC @tfeher