GridSearch for Multi-label classification in Scikit-learn

2024/10/5 1:17:17

I am trying to run a grid search for the best hyper-parameters within each fold of a ten-fold cross-validation. This worked fine in my previous multi-class classification work, but it fails this time with multi-label data.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
clf = OneVsRestClassifier(LinearSVC())
C_range = 10.0 ** np.arange(-2, 9)
# The wrapped LinearSVC is exposed as `estimator`, so its C is addressed as estimator__C
param_grid = dict(estimator__C=C_range)
clf = GridSearchCV(clf, param_grid)
clf.fit(X_train, y_train)

I am getting the error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-65-dcf9c1d2e19d> in <module>()
      6 
      7 clf = GridSearchCV(clf, param_grid)
----> 8 clf.fit(X_train, y_train)

/usr/local/lib/python2.7/site-packages/sklearn/grid_search.pyc in fit(self, X, y)
    595 
    596         """
--> 597         return self._fit(X, y, ParameterGrid(self.param_grid))
    598 
    599 

/usr/local/lib/python2.7/site-packages/sklearn/grid_search.pyc in _fit(self, X, y, parameter_iterable)
    357                                  % (len(y), n_samples))
    358             y = np.asarray(y)
--> 359         cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
    360 
    361         if self.verbose > 0:

/usr/local/lib/python2.7/site-packages/sklearn/cross_validation.pyc in _check_cv(cv, X, y, classifier, warn_mask)
   1365             needs_indices = None
   1366         if classifier:
-> 1367             cv = StratifiedKFold(y, cv, indices=needs_indices)
   1368         else:
   1369             if not is_sparse:

/usr/local/lib/python2.7/site-packages/sklearn/cross_validation.pyc in __init__(self, y, n_folds, indices, shuffle, random_state)
    427         for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
    428             for label, (_, test_split) in zip(unique_labels, per_label_splits):
--> 429                 label_test_folds = test_folds[y == label]
    430                 # the test split can be too big because we used
    431                 # KFold(max(c, self.n_folds), self.n_folds) instead of

ValueError: boolean index array should have 1 dimension

This might refer to the dimension or the format of the label indicator matrix.

print X_train.shape, y_train.shape

gives:

(147, 1024) (147, 6)

It seems that GridSearchCV uses StratifiedKFold internally for classifiers. The problem arises when the stratified K-fold strategy is applied to a multi-label target.
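A minimal sketch of why a 2-D label indicator matrix trips the `y == label` indexing seen in the traceback. The arrays below are hypothetical stand-ins with the same shapes as in the question, not the actual data, and the exact exception type and message depend on the NumPy version.

```python
import numpy as np

# Hypothetical stand-ins: 147 samples, 6 labels, as in the shapes printed above.
rng = np.random.RandomState(0)
y = rng.randint(0, 2, size=(147, 6))       # multi-label indicator matrix (2-D)
test_folds = np.zeros(len(y), dtype=int)   # one fold id per sample (1-D)

# Comparing a 2-D y against a scalar label yields a 2-D boolean mask.
mask = (y == 1)
print(mask.shape)  # (147, 6)

# A 1-D array cannot be indexed with a 2-D boolean mask, so this raises.
try:
    test_folds[mask]
    failed = False
except (IndexError, ValueError) as exc:
    failed = True
    print("indexing failed:", exc)
```

With a 1-D multi-class target the mask would also be 1-D and the indexing would succeed, which is consistent with the same code working for multi-class data.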

StratifiedKFold(y_train, 10)

gives

ValueError                                Traceback (most recent call last)
<ipython-input-87-884ffeeef781> in <module>()
----> 1 StratifiedKFold(y_train, 10)

/usr/local/lib/python2.7/site-packages/sklearn/cross_validation.pyc in __init__(self, y, n_folds, indices, shuffle, random_state)
    427         for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
    428             for label, (_, test_split) in zip(unique_labels, per_label_splits):
--> 429                 label_test_folds = test_folds[y == label]
    430                 # the test split can be too big because we used
    431                 # KFold(max(c, self.n_folds), self.n_folds) instead of

ValueError: boolean index array should have 1 dimension

Using the conventional (non-stratified) K-fold strategy works fine. Is there any way to apply stratified K-fold to multi-label classification?
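For reference, the conventional K-fold fallback can be sketched as below. This assumes a recent scikit-learn on Python 3, where the relevant classes live in `sklearn.model_selection` rather than the `grid_search` and `cross_validation` modules shown in the traceback. The synthetic data from `make_multilabel_classification`, the shortened `C` range, and the choice of 5 splits are illustrative assumptions, not the original setup.

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.model_selection import GridSearchCV, KFold, train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC

# Synthetic multi-label data standing in for the real X and y (assumption).
X, y = make_multilabel_classification(
    n_samples=220, n_features=20, n_classes=6, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=0
)

# The inner LinearSVC is reached through the `estimator` attribute of the wrapper.
param_grid = {"estimator__C": 10.0 ** np.arange(-2, 3)}

# An explicit, non-stratified splitter sidesteps stratification entirely.
cv = KFold(n_splits=5, shuffle=True, random_state=0)

search = GridSearchCV(
    OneVsRestClassifier(LinearSVC(max_iter=10000)), param_grid, cv=cv
)
search.fit(X_train, y_train)
print(search.best_params_)
```

Passing the splitter explicitly through `cv` keeps the behaviour independent of whatever default the installed version picks for multi-label targets.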

Answer

Grid search performs stratified cross-validation for classification problems, but this is not implemented for multi-label tasks; in fact, multi-label stratification is an unsolved problem in machine learning. I recently faced the same issue, and the only literature I could find was a method proposed in this article (whose authors state that they could not find any other attempts at solving this either).
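To make the idea of multi-label stratification concrete, here is a rough greedy heuristic. It is an illustrative sketch only and is not the method from the article mentioned above. The function name `greedy_multilabel_folds` and the demo data are hypothetical. Each sample is placed in the fold that currently holds the fewest positives of that sample's labels, and samples carrying rare labels are placed first.

```python
import numpy as np

def greedy_multilabel_folds(y, n_folds, seed=0):
    """Return a fold id per row of indicator matrix y, balancing label counts greedily."""
    y = np.asarray(y)
    n_samples, n_labels = y.shape
    rng = np.random.RandomState(seed)
    fold_of = np.full(n_samples, -1)
    fold_sizes = np.zeros(n_folds, dtype=int)
    label_counts = np.zeros((n_folds, n_labels), dtype=int)

    # Rarity of a sample = frequency of its rarest label; unlabeled rows go last.
    label_freq = y.sum(axis=0)
    rarity = np.where(y > 0, label_freq, n_samples + 1).min(axis=1)
    # Sort by rarity, breaking ties randomly (the last lexsort key is the primary one).
    order = np.lexsort((rng.rand(n_samples), rarity))

    for i in order:
        # Cost of a fold: positives it already holds for this sample's labels.
        cost = label_counts @ y[i]
        # Pick the lowest cost, breaking ties by the smallest fold.
        best = np.lexsort((fold_sizes, cost))[0]
        fold_of[i] = best
        fold_sizes[best] += 1
        label_counts[best] += y[i]
    return fold_of

# Hypothetical demo data: 147 samples, 6 labels.
rng = np.random.RandomState(0)
y_demo = (rng.rand(147, 6) < 0.3).astype(int)
folds = greedy_multilabel_folds(y_demo, n_folds=10)
for k in range(10):
    # Per-fold count of positives for each of the 6 labels.
    print(k, y_demo[folds == k].sum(axis=0))
```

The heuristic gives no guarantee on fold sizes or label balance. In a recent scikit-learn, the resulting fold ids could be handed to `sklearn.model_selection.PredefinedSplit` and passed to `GridSearchCV` through its `cv` argument.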
