IndexError in get_one_hot #5

@NinelK

Hi!

I'm trying to run the following:

idSysF = DPADModel()
args = DPADModel.prepare_args(selectedMethodCode)
idSysF.fit(yTrain.T, Z=zTrain.T, nx=n_factors, n1=n_beh_factors, epochs=epochs, **args)
zTestPredF, yTestPredF, xTestPredF = idSysF.predict(yTest) # Run inference to generate predictions

It works fine on most datasets, but on one of them it throws:

Traceback (most recent call last):
  File "/disk/scratch/nkudryas/BAND-torch/scripts/chewie/run_DPAD.py", line 76, in <module>
    idSysF.fit(yTrain.T, Z=zTrain.T, nx=n_factors, n1=n_beh_factors, epochs=epochs, **args)
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/DPADModel.py", line 1661, in fit
    history1_Cy = model1_Cy.fit(
                  ^^^^^^^^^^^^^^
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/RegressionModel.py", line 443, in fit
    inputs_val, outputs_val = prep_IO_data(
                              ^^^^^^^^^^^^^
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/RegressionModel.py", line 418, in prep_IO_data
    outputs = get_one_hot(np.array(outputs, dtype=int), self.num_classes)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/tools/tools.py", line 452, in get_one_hot
    res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
          ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index 5 is out of bounds for axis 0 with size 5
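The failing line picks rows of an identity matrix, so any target equal to or larger than nb_classes has no row to select. A minimal reproduction with plain NumPy, assuming get_one_hot reduces to the single expression shown in the traceback (the wrapper name one_hot_like_dpad is hypothetical, not part of DPAD):

```python
import numpy as np

def one_hot_like_dpad(targets, nb_classes):
    # Hypothetical stand-in for the expression at tools.py line 452:
    # row i of the identity matrix is the one-hot vector for class i.
    return np.eye(nb_classes)[np.array(targets).reshape(-1)]

print(one_hot_like_dpad([0, 2, 4], 5).shape)  # (3, 5)

try:
    one_hot_like_dpad([0, 1, 2, 3, 4, 5], 5)  # label 5 has no row in np.eye(5)
except IndexError as err:
    print(err)  # index 5 is out of bounds for axis 0 with size 5
```

With 5 classes the valid labels are 0 to 4, so a single sample labelled 5 is enough to trigger the error.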

I added an assertion to get_one_hot to catch this indexing error and inspect the values passed into it. It looks like an off-by-one issue:

  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/RegressionModel.py", line 418, in prep_IO_data
    outputs = get_one_hot(np.array(outputs, dtype=int), self.num_classes)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/scratch/nkudryas/micromamba/envs/dpad/lib/python3.11/site-packages/DPAD/tools/tools.py", line 453, in get_one_hot
    assert np.all((targets_array >= 0) & (targets_array < nb_classes)), f"Targets must be integers from 0 to {nb_classes-1}, not {np.unique(targets_array)}"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Targets must be integers from 0 to 4, not [0 1 2 3 4 5]
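The assertion shows six distinct labels (0 to 5) reaching get_one_hot while num_classes is 5. One possible explanation, not confirmed against the DPAD source: num_classes is counted from the labels in one part of the data (for example the samples used for fitting), while the validation part contains a label that the other part never sees. A sketch to check for this on toy data; compare_label_sets and the split point n_fit are hypothetical and are not DPAD's actual validation split:

```python
import numpy as np

def compare_label_sets(z_fit, z_val):
    """Hypothetical helper: return the label sets of both parts and the
    labels that occur in the validation part but not in the fitting part."""
    fit_classes = np.unique(np.asarray(z_fit, dtype=int))
    val_classes = np.unique(np.asarray(z_val, dtype=int))
    missing = np.setdiff1d(val_classes, fit_classes)
    return fit_classes, val_classes, missing

# Toy labels: class 5 only occurs near the end of the recording,
# so a chronological split puts it entirely in the validation part.
z = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 5])
n_fit = 8  # hypothetical split point
fit_classes, val_classes, missing = compare_label_sets(z[:n_fit], z[n_fit:])
print(len(fit_classes), missing)  # 5 [5]
```

If a split like this reproduces the pattern on the real zTrain (5 classes counted, one extra label in the held-out part), that would account for the off-by-one.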

What is the best way to fix this? I tried increasing num_classes by 1 at line 1150 of DPADModel.py (Cy2_args["num_classes"] = len(YClasses)), but that leads to shape mismatch errors.
Why does num_classes (5) not cover all target values (0 to 5)? Should the targets be corrected somehow?
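If the labels turn out to have gaps (for example 0, 1, 3 and 7), one workaround to try, assuming DPAD expects class labels to run from 0 to num_classes-1, is to remap them with a single shared mapping before calling fit. A sketch built on np.unique(..., return_inverse=True); to_contiguous_labels is a hypothetical helper, not part of DPAD:

```python
import numpy as np

def to_contiguous_labels(*label_arrays):
    """Hypothetical helper: map arbitrary integer labels in several arrays
    to 0..K-1 using one mapping shared across all of them."""
    sizes = [np.asarray(a).size for a in label_arrays]
    all_labels = np.concatenate(
        [np.asarray(a, dtype=int).reshape(-1) for a in label_arrays]
    )
    # classes[k] is the original label that is encoded as k.
    classes, codes = np.unique(all_labels, return_inverse=True)
    # Split the flat codes back into pieces and restore each input's shape.
    pieces = np.split(codes, np.cumsum(sizes)[:-1])
    remapped = [p.reshape(np.shape(a)) for p, a in zip(pieces, label_arrays)]
    return classes, remapped

z_train = np.array([[0, 1, 3, 7, 7]])  # toy labels with gaps
z_test = np.array([[1, 3, 7]])
classes, (z_train_idx, z_test_idx) = to_contiguous_labels(z_train, z_test)
print(classes)      # [0 1 3 7]
print(z_train_idx)  # [[0 1 2 3 3]]
```

Remapping would presumably not help if a class is absent from the data DPAD fits on, since num_classes would still be counted from the classes present there; in that case the options would be to make sure every class appears in the training portion, or to merge or drop the rare class.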

Many thanks,
Nina