Keras: Optimal epoch selection

2024/11/16 3:27:33

I'm trying to write some logic that selects the best epoch to run a neural network in Keras. My code saves the training loss and the test loss for a set number of epochs and then picks the best fitting epoch according to some logic. The code looks like this:

from pandas import DataFrame, concat

# history is the object returned by model.fit(...)
ini_epochs = 100
df_train_loss = DataFrame(data=history.history['loss'], columns=['Train_loss'])
df_test_loss = DataFrame(data=history.history['val_loss'], columns=['Test_loss'])
df_loss = concat([df_train_loss, df_test_loss], axis=1)

# Start from the worst test loss and look for something better
Min_loss = max(df_loss['Test_loss'])
for i in range(ini_epochs):
    Test_loss = df_loss['Test_loss'][i]
    Train_loss = df_loss['Train_loss'][i]
    if Test_loss > Train_loss and Test_loss < Min_loss:
        Min_loss = Test_loss

The idea behind the logic is this: to get the best model, select the epoch with the lowest test loss, provided that the test loss is above the training loss, to avoid overfitting.

In general, this epoch selection method works OK. However, if the test loss value is below the train loss from the start, then this method picks an epoch of zero (see the graph below).

[Image: graph of train and test loss per epoch]

Now I could add another if statement assessing whether the difference between the test and train losses is positive or negative, and then write logic for each case, but what happens if the difference starts positive and then ends up negative? I get confused and haven't been able to write effective code.

So, my questions are:

1) Can you show me what code you would write to account for the situation shown in the graph (and for the case where the test and train loss curves cross)? I'd say the strategy would be to take the value with the minimum difference.

2) There is a good chance that I'm going about this the wrong way. I know Keras has a callbacks feature but I don't like the idea of using the save_best_only feature because it can save overfitted models. Any advice on a more efficient epoch selection method would be great.

Answer

Use the EarlyStopping callback, which is available in Keras. Early stopping means stopping the training once your validation loss stops improving and starts to increase (which typically coincides with validation accuracy starting to decrease). Use ModelCheckpoint to save the model wherever you want.
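To make the stopping rule concrete, here is a minimal sketch of a patience-based rule in plain Python. It is only an illustration of the idea, not Keras's actual implementation; the function name early_stopping_epoch and the sample loss values are made up for the example.

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Return (stop_epoch, best_epoch), both 0-based.

    Hypothetical helper that mimics a patience rule on a list of
    per-epoch validation losses; it is not part of Keras.
    """
    best_loss = float("inf")
    best_epoch = None
    wait = 0  # number of epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            # New best validation loss: remember it and reset the counter
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                # No improvement for `patience` epochs in a row: stop here
                return epoch, best_epoch
    # Ran through every epoch without triggering the stop
    return len(val_losses) - 1, best_epoch


# Made-up validation losses: best at epoch 2, then no further improvement
val_losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.7, 0.8]
stop_epoch, best_epoch = early_stopping_epoch(val_losses, patience=3)
print(stop_epoch, best_epoch)  # 5 2
```

Training stops a few epochs after the best one, which is why the best weights need to be saved separately. The Keras callbacks do this bookkeeping during training: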

from keras.callbacks import EarlyStopping, ModelCheckpoint

STAMP = 'simple_lstm_glove_vectors_%.2f_%.2f' % (rate_drop_lstm, rate_drop_dense)
early_stopping = EarlyStopping(monitor='val_loss', patience=5)
bst_model_path = STAMP + '.h5'
model_checkpoint = ModelCheckpoint(bst_model_path, save_best_only=True, save_weights_only=True)

hist = model.fit(data_train, labels_train,
                 validation_data=(data_val, labels_val),
                 epochs=50, batch_size=256, shuffle=True,
                 callbacks=[early_stopping, model_checkpoint])

model.load_weights(bst_model_path)
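On the first question, the "minimum difference" strategy suggested in the question could be sketched as below. This is only one possible reading of that idea, not an established method: the function name select_epoch and the loss values are hypothetical, and ties are broken in favour of the lower test loss as an assumption. Using the absolute gap means the same code handles test loss starting below train loss and the curves crossing.

```python
def select_epoch(train_losses, val_losses):
    """Return the 0-based epoch with the smallest |val - train| gap.

    Hypothetical helper; ties go to the epoch with the lower validation loss.
    """
    if len(train_losses) != len(val_losses) or not train_losses:
        raise ValueError("need two non-empty loss lists of equal length")
    return min(
        range(len(val_losses)),
        key=lambda i: (abs(val_losses[i] - train_losses[i]), val_losses[i]),
    )


# Illustrative numbers only: test loss starts below train loss, then the curves cross
train = [1.00, 0.80, 0.60, 0.45, 0.35, 0.28]
val = [0.90, 0.75, 0.62, 0.50, 0.48, 0.49]

closest_gap = select_epoch(train, val)
lowest_val = min(range(len(val)), key=val.__getitem__)
print(closest_gap, lowest_val)  # 2 4
```

With a real run, the two lists would be history.history['loss'] and history.history['val_loss']. One caveat: the smallest gap can occur early, while both losses are still high, so it is worth comparing the result with the epoch that has the lowest validation loss before settling on it.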

