
[BUG] Handle non-zero-indexed random effects labels for intercept_only and intercept_plus_treatment models in Python BART / BCF #251

@andrewherren

Description

Both BARTModel and BCFModel in Python allow random effects models to be specified without an explicit user-supplied basis, via the intercept_only and intercept_plus_treatment values of the model_spec option. However, the predict method assumes that the provided group IDs are zero-indexed, because the group IDs are used directly as indices into an array of sampled random effects. The R implementation handles this correctly by using the LabelMapper class.

We should implement this same logic in Python.
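For reference, the core label-mapping idea can be sketched in Python as below. This is a minimal sketch under stated assumptions, not stochtree code and not a port of the R LabelMapper: the class name GroupLabelMapper and its methods are hypothetical, and raising an error for labels unseen during sampling is a design choice of this sketch only.

```python
import numpy as np


class GroupLabelMapper:
    """Hypothetical sketch: map arbitrary group labels to contiguous zero-based indices."""

    def __init__(self, labels):
        # Sorted unique labels seen at sampling time; a label's position is its index
        self.labels_ = np.unique(np.asarray(labels))
        self._index = {label: i for i, label in enumerate(self.labels_.tolist())}

    def map_to_index(self, labels):
        # Translate each label to its zero-based index; unseen labels raise ValueError
        labels = np.asarray(labels)
        try:
            return np.array([self._index[label] for label in labels.tolist()], dtype=np.int64)
        except KeyError as err:
            raise ValueError(f"Group label {err.args[0]!r} was not seen during sampling") from err

    def map_to_label(self, indices):
        # Inverse mapping, e.g. for reporting per-group estimates under the original labels
        return self.labels_[np.asarray(indices)]


# Labels 2, 3, 4 (as in the reproducer below) map to indices 0, 1, 2
mapper = GroupLabelMapper([2, 4, 3, 2])
indices = mapper.map_to_index([2, 3, 4])
```

In this sketch, sample() would build the mapper from rfx_group_ids_train and keep it with the model, and predict() would call map_to_index() on rfx_group_ids before indexing into the sampled random effects.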

Reproducing

  1. Generate data with non-zero-indexed random effects group IDs:

```python
import numpy as np
from stochtree import BARTModel

rng = np.random.default_rng()
n = 100
p = 10
X = rng.uniform(0, 1, (n, p))
num_groups = 3
# Group IDs take the values 2, 3, 4 rather than 0, 1, 2
group_ids = rng.choice(num_groups, size=n) + 2
random_intercepts = rng.uniform(0, 1, num_groups)
rfx_term = random_intercepts[group_ids - 2]
y = X[:, 0] + rfx_term + rng.normal(0, 1, n)
```

  2. Sample and predict from a BART model:

```python
bart_model = BARTModel()
bart_model.sample(X_train=X, y_train=y, rfx_group_ids_train=group_ids,
                  random_effects_params={'model_spec': 'intercept_only'})
bart_model.predict(X=X, rfx_group_ids=group_ids)
```
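The zero-indexing problem can be illustrated with plain numpy. This is only an analogy that assumes, as the description states, that sampled intercepts are stored in an array indexed directly by group ID; it does not exercise stochtree, and the actual failure in predict may surface differently.

```python
import numpy as np

# One sampled intercept per group, stored at positions 0, 1, 2
random_intercepts = np.array([0.1, 0.5, 0.9])
# Labels used in the reproducer: label 2 should map to position 0
group_ids = np.array([2, 3, 4])

# Label 2 is a valid position, but it silently selects the intercept belonging to label 4
wrong_value = random_intercepts[group_ids[0]]

# Labels 3 and 4 fall outside the array entirely
try:
    random_intercepts[group_ids]
    raised = False
except IndexError:
    raised = True
```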

Expected behavior

The .predict() call above should succeed, since it is run on the same group IDs that were used in sampling.
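Until this is fixed, one possible user-side workaround is to remap group IDs to 0, ..., k-1 before calling sample and predict, using the same mapping for both. The helper below (zero_index_group_ids, a hypothetical name) is a sketch that assumes every group ID passed to predict also appears in the training data.

```python
import numpy as np


def zero_index_group_ids(train_ids, predict_ids):
    """Hypothetical helper: remap training and prediction group IDs to 0..k-1."""
    # Sorted unique training labels; train_idx gives each training ID's position in labels
    labels, train_idx = np.unique(np.asarray(train_ids), return_inverse=True)
    predict_ids = np.asarray(predict_ids)
    # Locate each prediction ID among the sorted labels (clipped so indexing stays in bounds)
    predict_idx = np.clip(np.searchsorted(labels, predict_ids), 0, len(labels) - 1)
    # searchsorted returns an insertion point, so verify the label actually matches
    if not np.array_equal(labels[predict_idx], predict_ids):
        raise ValueError("predict_ids contains group IDs not present in train_ids")
    return train_idx, predict_idx


# Example with labels 2, 3, 4 as in the reproducer
train_idx, predict_idx = zero_index_group_ids([2, 4, 3, 2], [3, 2])
```

The first returned array would be passed as rfx_group_ids_train to sample and the second as rfx_group_ids to predict.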

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests