Feature/generic ensemble#366

Open
skywardfire1 wants to merge 3 commits into smartcorelib:development from skywardfire1:feature/generic_ensemble

Conversation

@skywardfire1
Contributor

This PR implements:

The Generic Ensemble Subsystem, as I call it

"What and why, Mr. Anderson?"

It allows a user to build their own custom ensemble models. More than that, since members are stored as Box<dyn Predictor>, a user can even combine models of different kinds!
In my project I use only 18 kNNs, but I was mainly interested in creating a universal ensemble, so... here is my attempt.

And that brings us to two limitations:

  • The ensemble does not support cloning, at least for now.
  • NaiveBayes and SVC cannot be added to the ensemble at this point. The reasons differ, and I believe each can be worked out. Still, the whole thing doesn't look as bright as planned...

Alright, with that being said...

🔑 Key Features

  • 🔄 Allows creation of heterogeneous predictor ensembles: KNN, Random Forest, and Decision Tree are currently supported out of the box.
    (Almost) any type implementing Predictor<X, Y> can be a member.

  • ⚖️ Two voting strategies, uniform or weighted: simple majority or confidence-based aggregation.
    Switch strategies at runtime with set_voting_strategy(); weights are validated on insertion.

  • 🎛️ Dynamic enable/disable of members at runtime: toggle models without retraining.
    Useful for A/B testing, fallback logic, or excluding underperforming models on-the-fly. My own idea!

  • 🏷️ Metadata: descriptions, tags...: document and organize your ensemble.
    Attach human-readable notes and group models by tags. I have no idea whether anyone will use it, but implementing it was fun and easy.

  • ⚖️ Set weights at any time
    Adjust voting influence with set_weight()

  • ✂️ Feature slicing via predict_using_names(): different inputs per model.
    Train models on disjoint feature subsets and combine predictions — ideal for multi-view learning.
    Again, this was crucial in my project, which is why I added it to Smartcore, though I'm unsure how broadly useful it is.

  • 📊 Built-in scoring: quick accuracy evaluation with score().
    Equivalent to accuracy(y, predict(x)) — included mainly to be more sklearn-ish
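Since score() is described as equivalent to accuracy(y, predict(x)), it is just the fraction of positions where prediction and label match. A tiny std-only illustration of that metric (not the PR's actual code; names are illustrative):

```rust
/// Fraction of positions where the prediction matches the label —
/// the quantity score() computes from predict()'s output.
fn accuracy(y_true: &[i32], y_pred: &[i32]) -> f64 {
    assert_eq!(y_true.len(), y_pred.len(), "label/prediction length mismatch");
    let hits = y_true.iter().zip(y_pred).filter(|(a, b)| a == b).count();
    hits as f64 / y_true.len() as f64
}

fn main() {
    // 3 of 4 predictions match the labels
    let acc = accuracy(&[1, 0, 1, 1], &[1, 0, 0, 1]);
    println!("Accuracy: {:.4}", acc); // 0.7500
}
```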

Documentation

📦 Model Management

  • 🔄 Heterogeneous ensembles: Mix KNN, Random Forest, Decision Tree, or any custom model implementing Predictor<X, Y>.
    No common base class required — trait-based composition.

  • 🎯 Three ways to add models (3 public methods total for model management):

    | Method | Use Case | Auto-name | Custom Name | Weight | Description | Tags |
    |---|---|---|---|---|---|---|
    | add(model) | Quick start | ✅ | ❌ | ❌ (Uniform only) | ❌ | ❌ |
    | add_named(name, model) | Named models for debugging | ❌ | ✅ | ❌ (Uniform only) | ❌ | ❌ |
    | add_with_params(name?, model, weight?, desc?, tags?) | Full control | ✅ | ✅ | ✅ | ✅ | ✅ |
  • 🏷️ Rich metadata: Attach descriptions, tags, and voting weights to each member. Query voting weight via weight(name).

    ℹ️ description and tags are stored internally but not exposed via public getters yet (reserved for future API).

  • ⚙️ Dynamic runtime control: Enable/disable individual models without retraining via enable(), disable(), enabled(). Perfect for A/B testing, fallback logic, or excluding underperformers on-the-fly.

🗳️ Voting Strategies

  • ⚖️ Uniform or Weighted voting: Simple majority or confidence-based aggregation. Switch at runtime with set_voting_strategy().

  • 🛡️ Rust-style strictness in Weighted mode:

    "Explicit is better than implicit."
    When using VotingStrategy::Weighted, every member must have an explicit, finite, non-negative weight. The API will fail fast with a clear error if you try to add a model without a weight — no silent defaults, no hidden magic. This prevents subtle bugs and makes ensemble behavior predictable.

  • 🔧 Weight management: Set or update weights anytime via set_weight(). Weights are validated on insertion and on strategy switch.
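The weighted aggregation described above boils down to summing, per candidate label, the weights of the members that voted for it, then picking the label with the largest total. A minimal std-only sketch of that tallying logic (illustrative only, not the PR's actual implementation):

```rust
use std::collections::HashMap;

/// Weighted majority vote over (label, weight) pairs.
/// With all weights equal, this reduces to uniform/simple majority voting.
fn weighted_vote(votes: &[(i32, f64)]) -> Option<i32> {
    let mut tally: HashMap<i32, f64> = HashMap::new();
    for &(label, weight) in votes {
        // accumulate each member's weight behind the label it predicted
        *tally.entry(label).or_insert(0.0) += weight;
    }
    tally
        .into_iter()
        // pick the label with the largest accumulated weight
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(label, _)| label)
}

fn main() {
    // model_a (weight 1.0) votes 0, model_b (weight 2.5) votes 1 — model_b wins
    let winner = weighted_vote(&[(0, 1.0), (1, 2.5)]);
    println!("{:?}", winner); // Some(1)
}
```

This is also why the strict weight validation matters: a missing or non-finite weight would silently corrupt the tally, so failing fast at insertion is the safer design.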

🔮 Prediction & Evaluation

  • 🎯 Two prediction modes (+ 1 scoring helper):
    | Method | Input | Use Case |
    |---|---|---|
    | predict(&x) | Single X for all models | Standard ensemble: all models see the same features |
    | predict_using_names(&HashMap<String, X>) | Per-model X via name | Feature slicing: each model gets its own feature subset |
    | score(&x, &y) -> f64 | Single X + labels Y | Quick accuracy evaluation (sklearn-style convenience) |
  • 📊 Built-in scoring: score() returns accuracy in [0.0, 1.0] — equivalent to accuracy(y, predict(x)), but convenient for cross-validation loops and hyperparameter tuning.

  • ✅ Type-safe predictions: All models in an ensemble must share the same X: Array2<f64> and Y: Array1<i32> + Clone, enforced at compile time via generics + PhantomData.
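The compile-time guarantee mentioned above comes from making the ensemble generic over X and Y and pinning those parameters with PhantomData, since no field stores them directly. A simplified sketch of the pattern (illustrative only, not the PR's actual code):

```rust
use std::marker::PhantomData;

// Each member only needs to satisfy this trait-object interface.
trait Predictor<X, Y> {
    fn predict(&self, x: &X) -> Y;
}

// The ensemble is generic over the shared input/output types; PhantomData
// keeps X and Y in the type even though members are behind trait objects.
struct Ensemble<X, Y> {
    members: Vec<Box<dyn Predictor<X, Y>>>,
    _marker: PhantomData<(X, Y)>,
}

impl<X, Y> Ensemble<X, Y> {
    fn new() -> Self {
        Ensemble { members: Vec::new(), _marker: PhantomData }
    }
    fn add(&mut self, model: Box<dyn Predictor<X, Y>>) {
        // A member with a different X or Y simply does not type-check here.
        self.members.push(model);
    }
    fn len(&self) -> usize {
        self.members.len()
    }
}

// A toy member that always predicts the same class.
struct ConstModel(i32);
impl Predictor<Vec<f64>, i32> for ConstModel {
    fn predict(&self, _x: &Vec<f64>) -> i32 {
        self.0
    }
}

fn main() {
    let mut e: Ensemble<Vec<f64>, i32> = Ensemble::new();
    e.add(Box::new(ConstModel(1)));
    println!("{} member(s), first predicts {}", e.len(), e.members[0].predict(&vec![0.0]));
}
```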

🧰 Introspection & Utilities

  • 🔍 Ensemble state: names(), len(), is_empty(), strategy(), get_ensemble_info() — query structure and configuration anytime.

  • 🏷️ Metadata queries: weight(name) — get voting weight for a member.

  • 🔄 Strategy switching: set_voting_strategy() validates all weights when switching to Weighted, ensuring consistency.


📚 Usage Guide: From Simple to Advanced

🎯 Scenario 1: The "Just Works" Way (3 lines)

Want the easy way? No problem! Just do:

use smartcore::ensemble::general_ensemble::Ensemble;
use smartcore::neighbors::knn_classifier::KNNClassifier;

// 1. Create ensemble (defaults to Uniform voting)
let mut ensemble = Ensemble::new();

// 2. Train and add models — names auto-generated ("model_0", "model_1", ...)
let knn1 = KNNClassifier::fit(&x_train, &y_train, params_k3)?;
let knn2 = KNNClassifier::fit(&x_train, &y_train, params_k5)?;
ensemble.add(knn1)?;
ensemble.add(knn2)?;

// 3. Predict — uniform voting out of the box
let predictions = ensemble.predict(&x_valid)?;
let acc = ensemble.score(&x_valid, &y_valid)?;
println!("Accuracy: {:.4}", acc);

✅ That's it. No weights, no names, no config.


🎯 Scenario 2: Name Your Models

Use add_named() when you want explicit control and better observability in your ensemble. Meaningful names make it easier to:

  • Debug individual model behavior
  • Enable/disable specific members at runtime
  • Log and audit which models contributed to a prediction
  • Manage A/B tests or canary deployments
let mut ensemble = Ensemble::new();

let knn_small = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(3))?;
let knn_large = KNNClassifier::fit(&x_train, &y_train, KNNClassifierParameters::default().with_k(15))?;

// Give them meaningful names — easier to debug and manage
ensemble.add_named("knn_k3".into(), knn_small)?;
ensemble.add_named("knn_k15".into(), knn_large)?;

// Later: inspect, enable/disable, debug by name
println!("Models: {:?}", ensemble.names());  // ["knn_k3", "knn_k15"]
ensemble.disable("knn_k3")?;                  // temporarily exclude from voting
let active = ensemble.enabled();              // ["knn_k15"]

💡 add_named() is syntactic sugar over add_with_params(Some(name), model, None, None, vec![]).


🎯 Scenario 3: Control Voting — Full Lifecycle

Step-by-step: Uniform → assign weights → switch to Weighted

// Step 1: Start with Uniform (default)
let mut ensemble = Ensemble::new();

let model_a = train_model_a()?;
let model_b = train_model_b()?;

// Add models without weights — works in Uniform mode
ensemble.add_named("model_a".into(), model_a)?;
ensemble.add_named("model_b".into(), model_b)?;

// Predict with simple majority voting
let preds_uniform = ensemble.predict(&x_valid)?;

// Step 2: Assign weights to models (prepare for Weighted mode)
ensemble.set_weight("model_a", 1.0)?;  // baseline
ensemble.set_weight("model_b", 2.5)?;  // higher confidence

// Step 3: Switch to Weighted strategy
// ⚠️ This will fail if any enabled member lacks a weight
ensemble.set_voting_strategy(VotingStrategy::Weighted)?;

// Now predictions use weighted voting
let preds_weighted = ensemble.predict(&x_valid)?;

// Compare results: switch back to Uniform for a fair side-by-side
ensemble.set_voting_strategy(VotingStrategy::Uniform)?;
println!("Uniform acc: {:.4}", ensemble.score(&x_valid, &y_valid)?);
ensemble.set_voting_strategy(VotingStrategy::Weighted)?;
println!("Weighted acc: {:.4}", ensemble.score(&x_valid, &y_valid)?);

🔁 Tip: You can switch strategies multiple times at runtime. Just ensure all members have valid weights before activating Weighted mode.


🎯 Scenario 4: Feature Slicing — Different Inputs per Model

For advanced use-cases like training on different feature subsets (multi-view learning).

let mut ensemble = Ensemble::with_strategy(VotingStrategy::Uniform);

// Train models on different feature slices
let model_a = train_on_features(&x_train, &[0,1,2])?;  // features 0-2
let model_b = train_on_features(&x_train, &[3,4,5])?;  // features 3-5

ensemble.add_named("slice_A".into(), model_a)?;
ensemble.add_named("slice_B".into(), model_b)?;

// Prepare inputs: each model gets its own feature slice
let mut inputs = HashMap::new();
inputs.insert("slice_A".into(), extract_features(&x_valid, &[0,1,2])?);
inputs.insert("slice_B".into(), extract_features(&x_valid, &[3,4,5])?);

// Predict with per-model inputs
let predictions = ensemble.predict_using_names(&inputs)?;

✂️ Constraint: All input matrices must have the same number of samples (rows), but can have different numbers of features (columns).
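extract_features and train_on_features above are user-supplied helpers, not part of this PR. For a row-major matrix, column selection can be as simple as picking the requested indices from each row; a std-only sketch of such a helper (names and types are illustrative):

```rust
/// Select a subset of columns from a row-major matrix.
/// Every slice keeps the same number of rows, which satisfies the
/// same-sample-count constraint stated above.
fn extract_features(x: &[Vec<f64>], cols: &[usize]) -> Vec<Vec<f64>> {
    x.iter()
        .map(|row| cols.iter().map(|&c| row[c]).collect())
        .collect()
}

fn main() {
    let x = vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
    let slice_a = extract_features(&x, &[0, 1]); // features 0-1
    let slice_b = extract_features(&x, &[2, 3]); // features 2-3
    assert_eq!(slice_a.len(), slice_b.len());    // same row count, different columns
    println!("{:?}", slice_a); // [[1.0, 2.0], [5.0, 6.0]]
}
```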


🎯 Scenario 5: Full Control — Metadata, Tags, Dynamic Management

let mut ensemble = Ensemble::with_strategy(VotingStrategy::Weighted);

// Add with full metadata
ensemble.add_with_params(
    Some("rf_prod_v2".into()),
    rf_model,
    Some(1.5),
    Some("Random Forest, depth=20, trained on Q1 data".into()),
    vec!["tree".into(), "production".into(), "v2".into()]
)?;

// Query metadata
assert_eq!(ensemble.weight("rf_prod_v2"), Some(1.5));
// ℹ️ description()/tags() getters are planned for future release

// Dynamic control at runtime
ensemble.disable("rf_prod_v2")?;   // exclude from voting
ensemble.enable("rf_prod_v2")?;    // re-include
ensemble.set_weight("rf_prod_v2", 2.0)?; // adjust influence

// Introspect state
let info = ensemble.get_ensemble_info();
println!("Strategy: {:?}, Members: {}/{}", info.strategy, info.enabled_members, info.total_members);

🏭 Real-World Usage Patterns

These come from my SAAN project.

Pattern 1: Auto-disable underperforming models

// Disable models with quality score below threshold
let threshold = 0.4;
for model_name in ensemble.names() {
    // Note: weight() here is repurposed as a model quality score, not a voting weight
    if let Some(w) = ensemble.weight(&model_name) {
        if w < threshold {
            ensemble.disable(&model_name)?;
            debug!("Disabled underperforming model: {}", model_name);
        }
    }
}

Pattern 2: Compare voting strategies on the same ensemble

// Evaluate Weighted voting
let predictions_weighted = ensemble.predict_using_names(&valid_x_combined)?;
let (acc_w, prec_w, rec_w, f1_w) = count_metrics!(&y_valid, &predictions_weighted);

// Switch to Uniform and re-evaluate
ensemble.set_voting_strategy(VotingStrategy::Uniform)?;
let predictions_uniform = ensemble.predict_using_names(&valid_x_combined)?;
let (acc_u, prec_u, rec_u, f1_u) = count_metrics!(&y_valid, &predictions_uniform);

println!("Weighted — Acc: {:.4}, F1: {:.4}", acc_w, f1_w);
println!("Uniform  — Acc: {:.4}, F1: {:.4}", acc_u, f1_u);

Pattern 3: Dynamically add a strong model and boost its influence

// Add Random Forest as a new ensemble member
let rf_params = RandomForestClassifierParameters::default()
    .with_n_trees(20)
    .with_max_depth(u16::MAX); // effectively unlimited depth
let rf_model = RandomForestClassifier::fit(&x_train, &y_train, rf_params)?;
ensemble.add_with_params(
    Some("random_forest".into()), 
    rf_model, 
    Some(1.0),  // initial weight
    Some("Random Forest, 20 trees".into()), 
    vec![]
)?;

// Include RF in the input map for feature-sliced prediction
valid_x_combined.insert("random_forest".into(), x_valid.clone());

// Predict with Uniform voting first
ensemble.set_voting_strategy(VotingStrategy::Uniform)?;
let preds_uniform_rf = ensemble.predict_using_names(&valid_x_combined)?;

// Then boost RF's influence in Weighted mode
ensemble.set_weight("random_forest", 0.9)?;
ensemble.set_voting_strategy(VotingStrategy::Weighted)?;
let preds_weighted_rf = ensemble.predict_using_names(&valid_x_combined)?;

📊 Interpreting Ensemble Logs

When running ensembles in production, you'll see structured output like this:

Ensemble: Strategy=Uniform, Active=19/19 members
After pruning (precision < 0.4): Strategy=Weighted, Active=13/19 members

=== Fold 4 Results ===
Baseline (mean):     Acc: 0.480, Prec: 0.467, Rec: 0.420, F1: 0.417
Uniform ensemble:    Acc: 0.697, Prec: 0.804, Rec: 0.606, F1: 0.667
Weighted ensemble:   Acc: 0.742, Prec: 0.792, Rec: 0.668, F1: 0.703  ← +5.3% F1
+RF, Uniform:        Acc: 0.727, Prec: 0.805, Rec: 0.658, F1: 0.705
+RF, Weighted:       Acc: 0.818, Prec: 0.855, Rec: 0.765, F1: 0.794  ← +12.6% F1 vs baseline

🔍 How to read this:

  1. Member count: Active=13/19 means 6 models were disabled due to low precision
  2. Strategy impact: Weighted voting improved F1 by 5.3% over Uniform on the same models
  3. Model addition: Adding Random Forest boosted performance further
  4. Weight tuning: Giving RF higher weight (0.9) in Weighted mode yielded the best result

💡 Pro tips:

  • Always log get_ensemble_info() before/after major changes
  • Compare metrics across strategies to validate your weighting scheme
  • Use enabled() to verify which models actually contributed to a prediction

🧪 Testing Philosophy

Our test suite covers:

| Test Category | Examples |
|---|---|
| ✅ Basic functionality | add(), add_named(), auto-names |
| ✅ Heterogeneous ensembles | KNN + RF + Decision Tree in one ensemble |
| ✅ Voting strategies | Uniform vs Weighted, weight validation |
| ✅ Feature slicing | predict_using_names() with per-model inputs |
| ✅ Runtime management | enable()/disable() affecting predictions |
| ✅ Metadata | Weights query/update |
| ✅ Error handling | Duplicate names, missing weights, empty ensemble |
| ✅ Scoring | score() validity across model additions/removals |

📊 Test coverage: 13 focused tests covering basic usage, error paths, voting strategies, feature slicing, and runtime management.

All tests use minimal, reproducible dummy data and verify both success and failure paths.


📋 Public API Summary

| Category | Methods | Count |
|---|---|---|
| Construction | new(), with_strategy() | 2 |
| Add models | add(), add_named(), add_with_params() | 3 |
| Metadata | set_weight(), set_description(), weight() | 3 |
| Runtime control | enable(), disable(), enabled() | 3 |
| Prediction | predict(), predict_using_names(), score() | 3 |
| Introspection | names(), len(), is_empty(), strategy(), get_ensemble_info() | 5 |
| Strategy | set_voting_strategy() | 1 |
| **Total public methods** | | **20** |

🎯 Of these, 3 methods add models, 3 methods predict, and 1 method scores — the core workflow in 7 calls.


⚠️ Common Pitfalls & How We Prevent Them

| Pitfall | Our Solution |
|---|---|
| Forgetting weights in Weighted mode | Runtime validation on add and on strategy switch; clear error message |
| Duplicate model names | add_with_params() checks HashMap keys; fails fast |
| Mismatched input dimensions in predict_using_names() | Array2 trait enforces shape; Failed error on mismatch |
| Using disabled models in voting | enabled() filter applied automatically in predict() |
| Type mismatches between models | ✅ Generics Ensemble<X, Y> enforce same input/output types at compile time |
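The duplicate-name fail-fast check can be expressed with HashMap's entry API, which distinguishes "name already taken" from "name free" in one lookup. A hedged sketch of that pattern (illustrative only, not the PR's actual code):

```rust
use std::collections::hash_map::Entry;
use std::collections::HashMap;

/// Register a member name, failing fast on duplicates instead of
/// silently overwriting the existing entry.
fn register(members: &mut HashMap<String, usize>, name: &str, id: usize) -> Result<(), String> {
    match members.entry(name.to_string()) {
        // name already present — reject with a clear error
        Entry::Occupied(_) => Err(format!("duplicate model name: {}", name)),
        // name free — claim it
        Entry::Vacant(slot) => {
            slot.insert(id);
            Ok(())
        }
    }
}

fn main() {
    let mut members = HashMap::new();
    assert!(register(&mut members, "knn_k3", 0).is_ok());
    assert!(register(&mut members, "knn_k3", 1).is_err()); // fails fast
    println!("{} member(s) registered", members.len()); // 1 member(s) registered
}
```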

🚀 What's Next? (Roadmap)

| Feature | Status | Priority |
|---|---|---|
| predict_proba() support | 🟡 Planned | High |
| UUID names | ⚪ Idea | Medium |
| Public getters for description() and tags() | ⚪ Idea | Medium |
| Auto-reset weights to None when switching Weighted → Uniform | ⚪ Idea | Low |

ℹ️ Note on strategy switching: When switching from Weighted to Uniform, weights are preserved but ignored. If you want a "clean slate", manually set weights to None via a future helper method (planned).

Ready for review. 🦀✨

@skywardfire1 skywardfire1 requested a review from Mec-iS as a code owner April 3, 2026 10:18
@codecov

codecov bot commented Apr 3, 2026

Codecov Report

❌ Patch coverage is 29.19255% with 114 lines in your changes missing coverage. Please review.
✅ Project coverage is 44.24%. Comparing base (70d8a0f) to head (469e9ee).
⚠️ Report is 14 commits behind head on development.

Files with missing lines Patch % Lines
src/ensemble/generic_ensemble.rs 29.19% 114 Missing ⚠️
Additional details and impacted files
@@               Coverage Diff               @@
##           development     #366      +/-   ##
===============================================
- Coverage        45.59%   44.24%   -1.35%     
===============================================
  Files               93       96       +3     
  Lines             8034     8190     +156     
===============================================
- Hits              3663     3624      -39     
- Misses            4371     4566     +195     

☔ View full report in Codecov by Sentry.

@skywardfire1
Contributor Author

Around 10 days of work. I'll fix the 2 failing builds soon.

@Mec-iS
Collaborator

Mec-iS commented Apr 3, 2026

wow this is great! thanks.
it will take some time to unroll.

it would be nice to have also #365 fixed with this so we can bump to v0.5.0
