Confusion matrix: "multilabel-indicator is not supported" in scikit-learn

Panjeh
1 min read · Jun 19, 2020

Sometimes the labels are one-hot encoded. For example, when the labels are digits, they are encoded like this:

1 -> [0,1,0,0,0,0,0,0,0,0]

5 -> [0,0,0,0,0,1,0,0,0,0]

We usually do this by:

from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
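To make the encoding concrete without requiring Keras, here is a minimal NumPy sketch of the same idea. The sample labels are made up for illustration, and indexing an identity matrix is only one way to build one-hot rows, not how to_categorical is implemented.

```python
import numpy as np

# Hypothetical digit labels, used only for illustration.
train_labels = np.array([1, 5, 0])

# Row i of the 10x10 identity matrix is the one-hot vector for digit i,
# so indexing the identity matrix by the labels encodes all of them at once.
one_hot = np.eye(10, dtype=int)[train_labels]

print(one_hot[0].tolist())  # [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
print(one_hot[1].tolist())  # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
```

Each label becomes a row of length 10 with a single 1 at the position of the digit.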

Issue:

In such cases, when you pass the one-hot encoded labels and the model's output directly to confusion_matrix, like:

from sklearn.metrics import confusion_matrix
print(confusion_matrix(test_labels,predictions))

you encounter an error like:

ValueError: multilabel-indicator is not supported

Solution:

To fix this, first decode the one-hot encoding by calling argmax on both "test_labels" and "predictions". The confusion matrix then comes out as you expect.

Like:

print(confusion_matrix(test_labels.argmax(axis=1), predictions.argmax(axis=1)))

Why do we do this? Because confusion_matrix expects a one-dimensional vector of class labels for each argument, not the one-hot encoded rows.
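The whole issue and its fix can be sketched end to end on a tiny example. The data here is invented (five samples, three classes), and true_classes, raised and cm are illustrative names. The exact wording of the error can differ depending on whether the predictions are probabilities or already one-hot, so the sketch only checks that a ValueError is raised.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical data: five samples, three classes.
true_classes = np.array([0, 1, 2, 2, 1])
test_labels = np.eye(3, dtype=int)[true_classes]  # one-hot, shape (5, 3)

# Made-up model output: one row of class probabilities per sample.
predictions = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.6, 0.2],
    [0.1, 0.8, 0.1],
])

# Passing the 2-D arrays directly is rejected with a ValueError.
try:
    confusion_matrix(test_labels, predictions)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Decoding each row to a single class index makes both inputs 1-D.
cm = confusion_matrix(test_labels.argmax(axis=1), predictions.argmax(axis=1))
print(cm)
```

In the resulting matrix, rows are the true classes and columns are the predicted classes, so the single off-diagonal entry is the one sample of class 2 that was predicted as class 1.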

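How does argmax work? It returns the index of the largest value along the given axis, so for a one-hot row it gives the position of the 1, which is the original class label. On a row of predicted probabilities it picks the most likely class.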
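A short NumPy sketch of argmax on the two digit encodings from the beginning of this post, plus a made-up array of probabilities (the names decoded, probs and predicted are only for illustration):

```python
import numpy as np

one_hot = np.array([
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],  # encodes the digit 1
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],  # encodes the digit 5
])

# axis=1 searches along each row and returns one index per row.
decoded = one_hot.argmax(axis=1)
print(decoded)  # [1 5]

# The same call on probabilities returns the most likely class per row.
probs = np.array([[0.1, 0.7, 0.2],
                  [0.6, 0.3, 0.1]])
predicted = probs.argmax(axis=1)
print(predicted)  # [1 0]
```

With axis=0 the search would run down each column instead, which is not what you want for per-sample labels.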
