Combination of GridSearchCV's refit and scorer unclear

2024/9/27 15:23:06

I use GridSearchCV to find the best parameters in the inner loop of my nested cross-validation. The 'inner winner' is found using GridSearchCV(scoring='balanced_accuracy'), so as I understand the documentation, the model with the highest average balanced accuracy across the inner folds is the 'best_estimator_'. I don't understand what the different arguments for refit in GridSearchCV do in combination with the scoring argument. If refit is True, what scoring function will be used to estimate the performance of that 'inner winner' when it is refitted to the dataset? The same scoring function that was passed to scoring (so in my case 'balanced_accuracy')? Why can you also pass a string to refit? Does that mean that you can use different functions for 1.) finding the 'inner winner' and 2.) estimating the performance of that 'inner winner' on the whole dataset?

Answer

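When refit=True, sklearn refits the model with the best parameter combination on the entire dataset that was passed to fit. At that point there is no held-out data left, so the refitted model is not scored with any scoring function. In nested cross-validation, estimating its performance is the job of the outer loop.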
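A minimal sketch of that split of responsibilities, assuming a synthetic dataset, a KNN classifier and an arbitrary small grid (none of these come from the question). The inner search selects by mean balanced accuracy, refit=True retrains the winner on everything passed to fit, and only the outer loop produces a performance estimate:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Synthetic, mildly imbalanced data; an assumption for illustration only.
X, y = make_classification(
    n_samples=200, n_features=10, weights=[0.7, 0.3], random_state=0
)

inner_cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)

search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid={"n_neighbors": [3, 5, 10]},  # hypothetical grid
    scoring="balanced_accuracy",
    cv=inner_cv,
    refit=True,  # retrain the winning parameters on all data given to fit
)
search.fit(X, y)

# best_score_ is the mean inner-fold score of the winning parameters,
# computed before the refit; it is not a score of the refitted model.
inner_best = search.cv_results_["mean_test_score"][search.best_index_]

# The outer loop clones the search, fits it on each outer training split,
# and scores the refitted best_estimator_ on the held-out outer fold.
outer_scores = cross_val_score(
    search, X, y, scoring="balanced_accuracy", cv=outer_cv
)
print(search.best_params_, search.best_score_, outer_scores.mean())
```

The scoring passed to cross_val_score is independent of the one used inside the search, so the outer estimate could use a different metric than the inner selection.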

If you pass multiple scorers to GridSearchCV, say f1 or precision along with your balanced_accuracy, sklearn needs to know which one of them to use to find the "inner winner", as you call it. For example, with KNN, f1 might give the best result at K=5, while accuracy might be highest at K=10. Without further input, sklearn has no way to know which value of the hyperparameter K is best.

To resolve that, you can pass the name of one of those scorers as a string to refit to specify which one should ultimately decide the best hyperparameters. The winning parameters are then used to refit the model on the full dataset. So when you have just one scorer, as seems to be your case, you don't have to worry about this: refit=True will suffice.
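As a rough illustration of the multi-metric case, here is a sketch that again assumes synthetic data, a KNN classifier and a made-up grid. Both metrics are recorded in cv_results_, but only the one named in refit picks the winner:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Synthetic binary data; an assumption for illustration only.
X, y = make_classification(
    n_samples=200, n_features=10, weights=[0.7, 0.3], random_state=0
)

search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid={"n_neighbors": [3, 5, 10]},  # hypothetical grid
    scoring=["balanced_accuracy", "f1"],     # two metrics are evaluated
    cv=3,
    refit="balanced_accuracy",               # this one decides the winner
)
search.fit(X, y)

results = search.cv_results_
bal_acc_means = results["mean_test_balanced_accuracy"]
f1_means = results["mean_test_f1"]

# The selected candidate has the highest mean balanced accuracy;
# its f1 is reported as well but plays no role in the selection.
print(search.best_params_)
print(bal_acc_means[search.best_index_], f1_means[search.best_index_])
```

With a list of metrics, the per-metric results appear under keys such as mean_test_balanced_accuracy and mean_test_f1, so the candidates can still be compared on the metric that was not used for selection.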
