
[FEA] Add Batching to KMeans#1886

Open
tarang-jain wants to merge 153 commits into rapidsai:release/26.04 from tarang-jain:batched-kmeans

Conversation

@tarang-jain
Contributor

@tarang-jain tarang-jain commented Mar 6, 2026

Merge after #1880

This PR adds support for streaming, out-of-core (dataset on host) k-means clustering. The idea is simple:

Batched accumulation of centroid updates: data is processed in batches, and batch-wise sums and cluster counts are accumulated until all batches, i.e., the full pass over the dataset, have been processed.
This PR simply introduces a batch-size parameter: batches of the dataset are loaded, and cluster assignments and (weighted) centroid contributions are computed per batch. The final centroid update, i.e. a single k-means iteration, only completes once these accumulated sums are averaged at the end of the full dataset pass.
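The accumulation scheme described above can be sketched in plain NumPy (a hypothetical illustration, not the cuvs implementation; `batched_kmeans_iteration` and its signature are made up for this sketch): one k-means iteration streams the dataset in batches, accumulates per-cluster sums and counts, and averages only after the full pass has completed.

```python
import numpy as np

def batched_kmeans_iteration(X, centroids, batch_size):
    """One Lloyd iteration computed by streaming X in batches of batch_size."""
    n_clusters, n_features = centroids.shape
    sums = np.zeros((n_clusters, n_features))
    counts = np.zeros(n_clusters)
    inertia = 0.0
    for start in range(0, len(X), batch_size):
        batch = X[start:start + batch_size]           # "load" one host batch
        # assign each sample in the batch to its nearest centroid
        d2 = ((batch[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = d2.argmin(axis=1)
        inertia += d2[np.arange(len(batch)), labels].sum()
        np.add.at(sums, labels, batch)                # accumulate batch-wise sums
        np.add.at(counts, labels, 1)                  # and cluster counts
    # the centroid update completes only after the whole dataset pass
    nonempty = counts > 0
    new_centroids = centroids.copy()
    new_centroids[nonempty] = sums[nonempty] / counts[nonempty, None]
    return new_centroids, inertia
```

Because the per-cluster sums and counts are exact (not decayed as in mini-batch k-means), one batched iteration yields the same centroids as a full-batch Lloyd iteration regardless of the batch size, up to floating-point summation order.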

Contributor

@jinsolp jinsolp left a comment


Adding some questions and a suggestion!

@achirkin
Contributor

Thanks for working on this much-needed feature, @tarang-jain! Could you please add a short paragraph to the PR description describing how the new batching is implemented? This is extremely helpful not only for review, but also for future revisions using tools like git blame, because the PR description is copied into the commit message in the main branch history.

Contributor

@viclafargue viclafargue left a comment


Thanks @tarang-jain! Here are my comments.

Comment on lines +504 to +512
if (score < 1.0) {
// std::stringstream ss;
// ss << "Expected: " << raft::arr2Str(d_labels_ref.data(), 25, "d_labels_ref", stream);
// std::cout << (ss.str().c_str()) << '\n';
// ss.str(std::string());
// ss << "Actual: " << raft::arr2Str(d_labels.data(), 25, "d_labels", stream);
// std::cout << (ss.str().c_str()) << '\n';
// std::cout << "Score = " << score << '\n';
}
Contributor


What is this code for?

Contributor Author


It's a debug print. It is also present in the other kmeans and kmeans_balanced tests.

@tarang-jain
Contributor Author

@achirkin I updated the PR description.

@tarang-jain
Contributor Author

I have renamed batch_size to streaming_batch_size to make its purpose clearer and differentiate it from the batch_samples and batch_centroids parameters.

Contributor

@viclafargue viclafargue left a comment


Thanks for the great work @tarang-jain! LGTM, just minor comments left.

Comment on lines +449 to +453
// Inertia for the last iteration is always computed
if (!params.inertia_check) {
raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream);
raft::resource::sync_stream(handle);
}
Contributor


Suggested change
// Inertia for the last iteration is always computed
if (!params.inertia_check) {
raft::copy(inertia.data_handle(), clustering_cost.data_handle(), 1, stream);
raft::resource::sync_stream(handle);
}
// Inertia for the last iteration is always computed after KMeans training

Do we truly need this copy here? This is probably dead code since we are computing the inertia unconditionally at the end.

centroids_regular, centroids_batched, rtol=1e-3, atol=1e-3
), f"max diff: {np.max(np.abs(centroids_regular - centroids_batched))}"

print(inertia_regular, inertia_batched)
Contributor


Do we want to keep the print here?

Comment on lines +175 to +176
auto workspace = rmm::device_uvector<char>(
batch_data.extent(0), stream, raft::resource::get_workspace_resource(handle));
Contributor


This is called once per batch. We could maybe move this outside of the loop in the caller function?
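The hoisting suggested above can be illustrated with a small Python analogue (illustrative only; `Workspace` and `process_all_batches` are hypothetical names, not the cuvs/rmm API): the scratch buffer is allocated once in the caller and only regrown when a batch exceeds its current capacity, instead of being allocated fresh on every batch.

```python
class Workspace:
    """A reusable scratch buffer that grows monotonically."""
    def __init__(self):
        self.buf = bytearray(0)

    def require(self, nbytes):
        # grow only when needed; reuse existing capacity otherwise
        if len(self.buf) < nbytes:
            self.buf = bytearray(nbytes)
        return self.buf

def process_all_batches(batches):
    ws = Workspace()                  # allocated once, outside the batch loop
    capacities = []
    for batch in batches:
        buf = ws.require(len(batch))  # no fresh allocation for smaller batches
        capacities.append(len(buf))
    return capacities
```

With batches of sizes 3, 1, and 5, only the first and third calls actually allocate; the second reuses the existing 3-byte buffer.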

Comment on lines +214 to +220
void fit(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const float, int> X,
std::optional<raft::host_vector_view<const float, int>> sample_weight,
raft::device_matrix_view<float, int> centroids,
raft::host_scalar_view<float> inertia,
raft::host_scalar_view<int> n_iter);
Contributor


It looks like the host fit functions are provided for both <T, int> and <T, int64_t>, but predict and fit_predict are not (only int64_t). Is this expected? Shouldn't we only expose int64_t even for the device functions? Are there performance implications?


workspace.resize(n_samples, stream);

raft::linalg::reduce_rows_by_key(const_cast<DataT*>(X.data_handle()),
Contributor


Suggested change
raft::linalg::reduce_rows_by_key(const_cast<DataT*>(X.data_handle()),
raft::linalg::reduce_rows_by_key(X.data_handle(),

The casting is unnecessary here as the first argument is const anyway.


Labels

cpp · feature request (New feature or request) · non-breaking (Introduces a non-breaking change)

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

8 participants