Confusion matrix: "multilabel-indicator is not supported" in scikit-learn
Sometimes your labels are one-hot encoded. For example, when the labels are digits, each one may be encoded as a 10-element vector:
1 -> [0,1,0,0,0,0,0,0,0,0]
5 -> [0,0,0,0,0,1,0,0,0,0]
We usually do this by:
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
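For readers who want to see what this encoding produces without installing Keras, here is a minimal NumPy-only sketch. The sample label values and the num_classes variable are assumptions made for illustration; this is not how Keras implements to_categorical internally.

```python
import numpy as np

# Hypothetical integer labels for a 10-class digit problem.
train_labels = np.array([1, 5, 0, 9])
num_classes = 10

# Row i of the identity matrix is the one-hot vector for class i,
# so indexing the identity matrix by the labels encodes all of them at once.
one_hot = np.eye(num_classes, dtype=int)[train_labels]

print(one_hot[0])  # [0 1 0 0 0 0 0 0 0 0]
print(one_hot[1])  # [0 0 0 0 0 1 0 0 0 0]
```

The result is a 2-D array with one row per sample, which is the shape that later causes trouble with confusion_matrix.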
Issue:
In such cases, if you pass the one-hot test labels and the model's output directly to confusion_matrix, like:
from sklearn.metrics import confusion_matrix
print(confusion_matrix(test_labels,predictions))
you encounter an error like:
ValueError: multilabel-indicator is not supported
Solution:
To fix this, first decode the one-hot arrays back to class indices by calling argmax
on both "test_labels" and "predictions"; the confusion matrix then comes out as you expect.
Like:
print(confusion_matrix(test_labels.argmax(axis=1), predictions.argmax(axis=1)))
Why do we do this? Because confusion_matrix expects a 1-D vector of class labels, one per sample, not a 2-D one-hot array, which scikit-learn treats as a multilabel-indicator target.
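The whole problem and its fix can be sketched end to end as follows. The small 3-class arrays are made-up stand-ins for real test labels and model output, and the exact wording of the error may differ depending on the scikit-learn version and on whether the predictions are probabilities or one-hot vectors.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical one-hot ground truth: 4 samples, 3 classes.
test_labels = np.array([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1],
                        [0, 1, 0]])

# Hypothetical model output: one row of class probabilities per sample.
predictions = np.array([[0.8, 0.1, 0.1],
                        [0.2, 0.7, 0.1],
                        [0.1, 0.6, 0.3],
                        [0.3, 0.5, 0.2]])

# Passing the 2-D arrays directly is rejected with a ValueError.
raised = False
try:
    confusion_matrix(test_labels, predictions)
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Decode each row to a single class index, then compute the matrix.
y_true = test_labels.argmax(axis=1)
y_pred = predictions.argmax(axis=1)
cm = confusion_matrix(y_true, y_pred)
print(cm)
```

Rows of the printed matrix correspond to true classes and columns to predicted classes.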
Here you will see how argmax works:
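A small illustrative sketch with NumPy, using the two one-hot vectors from the beginning of this post plus a made-up array of probabilities: argmax with axis=1 returns, for each row, the index of the largest value, which is the class label.

```python
import numpy as np

# The one-hot encodings of the digits 1 and 5 shown earlier.
one_hot = np.array([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]])

# axis=1 searches along each row, giving one index per sample.
decoded = one_hot.argmax(axis=1)
print(decoded)  # [1 5]

# The same call works on predicted probabilities (hypothetical values):
# the index of the highest probability is the predicted class.
probs = np.array([[0.1, 0.7, 0.2],
                  [0.5, 0.3, 0.2]])
predicted = probs.argmax(axis=1)
print(predicted)  # [1 0]
```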