Can I safely assign to `coef_` and other estimated parameters in scikit-learn?

2024/9/8 12:01:41

scikit-learn suggests the use of pickle for model persistence. However, its documentation notes the limitations of pickle when it comes to different versions of scikit-learn or Python. (See also this Stack Overflow question)

In many machine learning approaches, only a few parameters are learned from large data sets. These estimated parameters are stored in attributes with a trailing underscore, e.g. coef_.

Now my question is the following: Can model persistence be achieved by persisting the estimated attributes and assigning to them later? Is this approach safe for all estimators in scikit-learn, or are there potential side-effects (e.g. private variables that have to be set) in the case of some estimators?

It seems to work for logistic regression, as seen in the following example:

from sklearn import datasets
from sklearn.linear_model import LogisticRegression

try:
    from sklearn.model_selection import train_test_split
except ImportError:
    # Fallback for older scikit-learn releases
    from sklearn.cross_validation import train_test_split

iris = datasets.load_iris()
tt_split = train_test_split(iris.data, iris.target, test_size=0.4)
X_train, X_test, y_train, y_test = tt_split

# Here we train the logistic regression
lr = LogisticRegression(class_weight='balanced')
lr.fit(X_train, y_train)
print(lr.score(X_test, y_test))  # prints 0.95

# Persisting
params = lr.get_params()
coef = lr.coef_
intercept = lr.intercept_
# classes_ is not documented as a public member,
# but not explicitly private (not starting with an underscore)
classes = lr.classes_
lr.n_iter_  # This is metadata. No need to persist

# Now we try to load the classifier
lr2 = LogisticRegression()
lr2.set_params(**params)
lr2.coef_ = coef
lr2.intercept_ = intercept
lr2.classes_ = classes
print(lr2.score(X_test, y_test))  # Prints the same: 0.95

Answer

Setting the estimated attributes alone is not enough - at least in the general case for all estimators.

I know of at least one example where this would fail. LinearDiscriminantAnalysis.transform() makes use of the private attribute _max_components:

def transform(self, X):
    # ... code omitted
    return X_new[:, :self._max_components]
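
To make the failure mode concrete, here is a minimal sketch that does not use scikit-learn at all. `ToyReducer` is a hypothetical class invented for illustration. It only mimics the pattern above: `fit()` sets a public `mean_` and a private `_max_components`, and `transform()` depends on both.

```python
class ToyReducer:
    """Hypothetical estimator for illustration; not a scikit-learn class."""

    def fit(self, X):
        n_features = len(X[0])
        # Public estimated attribute: per-column mean.
        self.mean_ = [sum(col) / len(X) for col in zip(*X)]
        # Private attribute that transform() silently depends on.
        self._max_components = min(2, n_features)
        return self

    def transform(self, X):
        centered = [[v - m for v, m in zip(row, self.mean_)] for row in X]
        return [row[: self._max_components] for row in centered]


X = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
fitted = ToyReducer().fit(X)

# Copy only the public, trailing-underscore attribute to a fresh instance.
clone = ToyReducer()
clone.mean_ = fitted.mean_

try:
    clone.transform(X)
    error_name = None
except AttributeError as exc:
    # The private attribute was never set on the clone.
    error_name = type(exc).__name__
    print("transform failed:", exc)
```

The clone looks fitted from the outside, yet `transform()` fails because the private state was never copied.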

However, it might work for some estimators. If you only need this for a specific estimator, the best approach would be to look at the estimator's source code and save all attributes that are set in the __init__() and .fit() methods.

A more generic approach could be to save all items in the estimator's .__dict__. E.g.:

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
lda = LDA().fit([[1, 2, 3], [1, 2, 1], [4, 5, 6], [9, 9, 9]], [1, 2, 1, 2])
lda.__dict__
# {'_max_components': 1,
#  'classes_': array([1, 2]),
#  'coef_': array([[ -9.55555556,  21.55555556,  -9.55555556]]),
#  'explained_variance_ratio_': array([ 1.]),
#  'intercept_': array([-15.77777778]),
#  'means_': array([[ 2.5,  3.5,  4.5],
#         [ 5. ,  5.5,  5. ]]),
#  'n_components': None,
#  'priors': None,
#  'priors_': array([ 0.5,  0.5]),
#  'scalings_': array([[-2.51423299],
#         [ 5.67164186],
#         [-2.51423299]]),
#  'shrinkage': None,
#  'solver': 'svd',
#  'store_covariance': False,
#  'tol': 0.0001,
#  'xbar_': array([ 3.75,  4.5 ,  4.75])}
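
As a rough sketch of how that dictionary could be used, the snippet below copies every entry of `__dict__` from a fitted LDA into a fresh instance and checks that both behave the same. The variable names (`fitted`, `state`, `restored`) are illustrative, and the sketch assumes the same scikit-learn version on both sides. It is not an officially supported persistence mechanism.

```python
import copy

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X = np.array([[1, 2, 3], [1, 2, 1], [4, 5, 6], [9, 9, 9]], dtype=float)
y = np.array([1, 2, 1, 2])

fitted = LinearDiscriminantAnalysis().fit(X, y)

# Snapshot every instance attribute: constructor parameters, estimated
# attributes (trailing underscore) and private ones alike.
state = copy.deepcopy(vars(fitted))

# Restore into a fresh, unfitted instance of the same class.
restored = LinearDiscriminantAnalysis()
vars(restored).update(state)

# Both objects should now give the same results on the same input.
same_transform = np.allclose(fitted.transform(X), restored.transform(X))
same_predict = np.array_equal(fitted.predict(X), restored.predict(X))
print(same_transform, same_predict)
```

In practice `state` would be written to disk in some format between the two steps. Its values are mostly NumPy arrays and plain Python objects, so the serialization format has to handle both.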

This won't be trivial for estimators that contain more complex data, such as ensembles that contain multiple estimators. See the blog post Scikit-learn Pipeline Persistence and JSON Serialization for more details.

Unfortunately, this will not safely carry estimators over to new versions of scikit-learn. Private attributes are essentially an implementation detail that could change at any time between releases.
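
One way to at least detect such a mismatch is to store the library version next to the saved state and refuse to restore when it differs. The following is only a sketch under that assumption. `pack_state`, `restore_state`, `VersionMismatchError` and `DummyEstimator` are hypothetical names, not scikit-learn APIs. With a real estimator you would pass `sklearn.__version__` as `library_version`.

```python
class VersionMismatchError(RuntimeError):
    """Raised when saved state comes from a different library version."""


def pack_state(estimator, library_version):
    """Bundle the estimator's attributes with the version that produced them."""
    return {"version": library_version, "state": dict(vars(estimator))}


def restore_state(estimator, payload, library_version):
    """Copy saved attributes onto `estimator`, but only for a matching version."""
    if payload["version"] != library_version:
        raise VersionMismatchError(
            "state saved with version %s, running version %s"
            % (payload["version"], library_version)
        )
    vars(estimator).update(payload["state"])
    return estimator


class DummyEstimator:
    """Hypothetical stand-in for an estimator class."""


# Simulate a fitted estimator with one public and one private attribute.
fitted = DummyEstimator()
fitted.coef_ = [0.5, -1.0]
fitted._private_detail = 1

payload = pack_state(fitted, "1.0.0")

# Same version: the state is restored.
restored = restore_state(DummyEstimator(), payload, "1.0.0")

# Different version: the restore is refused instead of failing later.
try:
    restore_state(DummyEstimator(), payload, "1.1.0")
    mismatch_detected = False
except VersionMismatchError as exc:
    mismatch_detected = True
    print("refused:", exc)
```

An exact version match is deliberately strict. It does not make the state portable; it only turns a silent incompatibility into an explicit error.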

