XGBoost and sparse matrix

2024/10/13 3:30:09

I am trying to use xgboost (from Python) on a classification problem. I have the data in a NumPy matrix X (rows = observations, columns = features) and the labels in a NumPy array y. Because my data are sparse, I would like to run it on a sparse version of X, but I seem to be missing something, because an error occurs.

Here is what I do:

# Library import
import numpy as np
import xgboost as xgb
from xgboost.sklearn import XGBClassifier
from scipy.sparse import csr_matrix

# Converting to sparse data and running xgboost
X_csr = csr_matrix(X)
xgb1 = XGBClassifier()
xgtrain = xgb.DMatrix(X_csr, label = y )      #to work with the xgb format
xgtest = xgb.DMatrix(Xtest_csr)
xgb1.fit(xgtrain, y, eval_metric='auc')
dtrain_predictions = xgb1.predict(xgtest)

etc...

Now I get an error when trying to fit the classifier:

File ".../xgboost/python-package/xgboost/sklearn.py", line 432, in fit
    self._features_count = X.shape[1]
AttributeError: 'DMatrix' object has no attribute 'shape'

I have looked for a while for where it could come from, and I believe it has to do with the sparse format I want to use. But what exactly the problem is, and how I could fix it, I have no clue.

I would welcome any help or comments! Thank you very much.

Answer

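You are using the xgboost scikit-learn API (http://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn), so you don't need to convert your data to a DMatrix to fit the XGBClassifier(): its fit() method accepts a SciPy sparse matrix directly and does the conversion itself. Passing a DMatrix is what triggers the error, because fit() reads X.shape, which a DMatrix does not have. Just remove the line

xgb.DMatrix(X_csr, label = y )

and pass X_csr itself to fit() (likewise, pass Xtest_csr rather than xgtest to predict()). This should work: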

type(X_csr) #scipy.sparse.csr.csr_matrix
type(y) #numpy.ndarray
xgb1 = xgb.XGBClassifier()
xgb1.fit(X_csr, y)

which outputs (this is from an older xgboost release; newer versions print a different set of parameters):

XGBClassifier(base_score=0.5, colsample_bylevel=1, colsample_bytree=1,
       gamma=0, learning_rate=0.1, max_delta_step=0, max_depth=3,
       min_child_weight=1, missing=None, n_estimators=100, nthread=-1,
       objective='binary:logistic', reg_alpha=0, reg_lambda=1,
       scale_pos_weight=1, seed=0, silent=True, subsample=1)
