How to get the params from a saved XGBoost model

2024/9/20 21:24:53

I'm trying to train an XGBoost model using the params below:

xgb_params = {
    'objective': 'binary:logistic',
    'eval_metric': 'auc',
    'lambda': 0.8,
    'alpha': 0.4,
    'max_depth': 10,
    'max_delta_step': 1,
    'verbose': True
}

Since my input data is too big to load fully into memory, I use incremental training:

xgb_clf = xgb.train(xgb_params, input_data, num_boost_round=rounds_per_batch, xgb_model=model_path)

The code for prediction is

xgb_clf = xgb.XGBClassifier()
booster = xgb.Booster()
booster.load_model(model_path)
xgb_clf._Booster = booster
raw_probas = xgb_clf.predict_proba(x)

The result seemed good. But when I tried to invoke xgb_clf.get_xgb_params(), I got a param dict in which all params were set to default values.

My guess at the root cause is that I didn't pass any params when I initialized the classifier. The classifier was therefore created with default values, but for prediction it used the internal booster, which had been fitted with my own params.

However, I wonder whether there is any way, after assigning a pre-trained booster to an XGBClassifier, to see the real params that were used to train the booster rather than those used to initialize the classifier.

Answer

You seem to be mixing the sklearn API with the functional API in your code. If you stick to either one, the parameters should persist in the pickle. Here's an example using the sklearn API.

import pickle
import numpy as np
import xgboost as xgb
from sklearn.datasets import load_digits

digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']

# eval_metric is set on the estimator; recent XGBoost releases no longer accept it in fit()
xgb_params = {
    'objective': 'binary:logistic', 'eval_metric': 'auc',
    'reg_lambda': 0.8, 'reg_alpha': 0.4,
    'max_depth': 10, 'max_delta_step': 1,
}
clf = xgb.XGBClassifier(**xgb_params)
clf.fit(X, y, verbose=True)

with open("xgb_temp.pkl", "wb") as f:
    pickle.dump(clf, f)
with open("xgb_temp.pkl", "rb") as f:
    clf2 = pickle.load(f)

assert np.allclose(clf.predict(X), clf2.predict(X))
print(clf2.get_xgb_params())

which produces the following (output from the XGBoost version this answer was written against; newer versions report a different set of keys):

{
    'base_score': 0.5,
    'colsample_bylevel': 1,
    'colsample_bytree': 1,
    'gamma': 0,
    'learning_rate': 0.1,
    'max_delta_step': 1,
    'max_depth': 10,
    'min_child_weight': 1,
    'missing': nan,
    'n_estimators': 100,
    'objective': 'binary:logistic',
    'reg_alpha': 0.4,
    'reg_lambda': 0.8,
    'scale_pos_weight': 1,
    'seed': 0,
    'silent': 1,
    'subsample': 1
}
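
If you stay with the functional API and a booster loaded via load_model, as in the question, one option is to inspect the booster's own configuration instead of the classifier's. Booster.save_config() returns the internal configuration as a JSON string. The sketch below is only an illustration: find_param and booster_config are hypothetical helper names, not part of XGBoost, and the search is recursive because the layout of that JSON differs between XGBoost versions. Values in the config are stored as strings. Also note that a file written by save_model may not retain every training hyperparameter, so some entries can come back as defaults; pickling the booster keeps the full state.

```python
import json


def find_param(config, name):
    """Return every value stored under `name` anywhere in a nested config."""
    found = []
    if isinstance(config, dict):
        for key, value in config.items():
            if key == name:
                found.append(value)
            # Keep descending: the same key can appear at several levels.
            found.extend(find_param(value, name))
    elif isinstance(config, list):
        for item in config:
            found.extend(find_param(item, name))
    return found


def booster_config(model_path):
    """Load a saved booster and return its internal configuration as a dict.

    Hypothetical helper; assumes `model_path` points to a file written by
    Booster.save_model().
    """
    import xgboost as xgb  # imported here so find_param works without xgboost

    booster = xgb.Booster()
    booster.load_model(model_path)
    return json.loads(booster.save_config())


# Usage (model_path is the file produced by the training loop in the question):
# config = booster_config(model_path)
# print(find_param(config, "max_depth"))
# print(find_param(config, "max_delta_step"))
```

If the values you need are missing from the loaded config, that is a sign the model file did not carry them, and keeping the original xgb_params next to the model file is the more dependable route.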