Why does Keras loss drop dramatically after the first epoch?

2024/9/20 0:25:06

I'm training a U-Net CNN in Keras/TensorFlow and find that the loss decreases massively between the last batch of the first epoch and the first batch of the second epoch:

Epoch 00001: loss improved from inf to 0.07185 - categorical_accuracy: 0.8636
Epoch 2/400: 1/250 [.....................] - loss: 0.0040 - categorical_accuracy: 0.8878

Weirdly, categorical accuracy does not jump along with the loss; it only increases slightly. After the drop, the loss doesn't decrease any further but settles around the lower value. I know this is very little information on the problem, but might this behaviour indicate a common problem I can investigate further?

Some extra info: Optimizer = Adam(lr=1e-4) (lowering lr didn't seem to help).

Loss: class-weighted categorical cross-entropy, calculated as follows:

    import tensorflow as tf
    from tensorflow.keras import backend as K

    def class_weighted_categorical_crossentropy(class_weights):
        def loss_function(y_true, y_pred):
            # make sure labels and predictions share a dtype
            y_true = tf.cast(y_true, y_pred.dtype)
            # scale preds so that the class probas of each sample sum to 1
            y_pred /= tf.reduce_sum(y_pred, -1, True)
            # manual computation of crossentropy
            epsilon = tf.convert_to_tensor(K.epsilon(), y_pred.dtype.base_dtype)
            y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
            # Multiply each class by its weight:
            classes_list = tf.unstack(y_true * tf.math.log(y_pred), axis=-1)
            for i in range(len(classes_list)):
                classes_list[i] = tf.scalar_mul(class_weights[i], classes_list[i])
            # Return weighted sum:
            return -tf.reduce_sum(tf.stack(classes_list, axis=-1), -1)
        return loss_function
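As a sanity check, the same per-pixel formula can be recomputed by hand outside TensorFlow. The sketch below is a minimal NumPy mirror of the loss function above; the helper name `weighted_ce_numpy` and the toy labels, predictions and weights are illustrative assumptions, not part of Keras or the original code. Comparing its output on a small batch against what Keras reports is a quick way to rule out a bug in the loss itself.

```python
import numpy as np

def weighted_ce_numpy(y_true, y_pred, class_weights, eps=1e-7):
    """Hypothetical helper: per-sample class-weighted categorical cross-entropy."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    w = np.asarray(class_weights, dtype=np.float64)
    # Normalise so the class probabilities of each sample sum to 1
    y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)
    # Clip to avoid log(0); 1e-7 is Keras' default K.epsilon()
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # Weight each class term, then sum over the class axis
    return -np.sum(w * y_true * np.log(y_pred), axis=-1)

# Two "pixels", three classes; class 2 is up-weighted (illustrative values)
y_true = [[1, 0, 0], [0, 0, 1]]
y_pred = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
weights = [1.0, 1.0, 5.0]

per_pixel = weighted_ce_numpy(y_true, y_pred, weights)
print(per_pixel)         # approx. -log(0.7) = 0.357 and -5*log(0.6) = 2.554
print(per_pixel.mean())  # by default Keras averages the per-pixel values
```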

Any ideas/sanity checks are much appreciated!

EDIT: This is the loss plot for training. I didn't have time to tidy it up; it's the loss plotted per step, not per epoch, and you can see the shift to epoch 2 after 250 steps. Up until that point the loss curve looks very good, but the shift at epoch two seems strange.

Answer

That sounds right to me. Remember, there is an inverse relationship between loss and accuracy: as loss decreases, accuracy tends to increase.
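That relationship is not proportional, though: cross-entropy keeps falling as correct predictions become more confident, even when the arg-max (and therefore accuracy) does not change. This may explain why the loss can fall sharply while categorical accuracy only creeps up. The sketch below uses made-up two-class probabilities and hypothetical helpers `mean_ce` and `accuracy`:

```python
import math

def mean_ce(probs_of_true_class):
    """Mean categorical cross-entropy given the predicted probability of the true class."""
    return -sum(math.log(p) for p in probs_of_true_class) / len(probs_of_true_class)

def accuracy(pred_rows, true_idx):
    """Fraction of rows whose arg-max equals the true class index."""
    hits = sum(1 for row, t in zip(pred_rows, true_idx) if row.index(max(row)) == t)
    return hits / len(true_idx)

true_idx = [0, 0, 1, 1]
# Illustrative "early" predictions: mostly right, but not confident
early = [[0.6, 0.4], [0.55, 0.45], [0.4, 0.6], [0.6, 0.4]]
# Illustrative "late" predictions: same arg-max on every row, much more confident
late = [[0.99, 0.01], [0.98, 0.02], [0.01, 0.99], [0.6, 0.4]]

for name, rows in [("early", early), ("late", late)]:
    p_true = [row[t] for row, t in zip(rows, true_idx)]
    print(name, "loss=%.3f" % mean_ce(p_true), "acc=%.2f" % accuracy(rows, true_idx))
```

Both sets of predictions get the same three rows right, so accuracy is identical, yet the loss drops by more than half.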

My understanding is that during the first epoch you basically have a neural network with a more-or-less random initial state. By the end of the first epoch, the weights will have been adjusted many times to minimize the loss function (which, as stated above, generally goes hand in hand with increasing accuracy). So at the beginning of the second epoch your loss should be a lot better (i.e. lower), which means your neural network is learning.
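It also helps to know how Keras reports the number. During an epoch, the progress bar (and, in TF 2.x, the `logs` passed to batch-level callbacks) shows the running mean of the loss over all batches seen so far in that epoch, and the accumulator is reset when the next epoch starts. The epoch-1 figure therefore still includes the high losses of the first, near-random batches, while the first value shown in epoch 2 is a single batch from an already-trained network. The simulation below assumes a made-up, exponentially decaying per-batch loss (not the asker's data) with 250 steps per epoch; `batch_loss` and `keras_style_report` are hypothetical names used only for illustration:

```python
import math

STEPS_PER_EPOCH = 250  # matches the "1/250" in the question's log

def batch_loss(step):
    """Hypothetical per-batch loss: starts near 1.0 and decays quickly toward 0.004."""
    return 0.004 + 0.996 * math.exp(-step / 20.0)

def keras_style_report(num_epochs, steps_per_epoch):
    """Return (epoch, step, displayed_loss) using a running mean reset every epoch."""
    reports = []
    global_step = 0
    for epoch in range(1, num_epochs + 1):
        total = 0.0  # accumulator reset at the start of every epoch
        for step in range(1, steps_per_epoch + 1):
            total += batch_loss(global_step)
            reports.append((epoch, step, total / step))
            global_step += 1
    return reports

reports = keras_style_report(2, STEPS_PER_EPOCH)
end_of_epoch1 = reports[STEPS_PER_EPOCH - 1][2]
start_of_epoch2 = reports[STEPS_PER_EPOCH][2]
raw_last_batch = batch_loss(STEPS_PER_EPOCH - 1)

print("displayed at end of epoch 1:  %.4f" % end_of_epoch1)
print("raw loss of that last batch:  %.4f" % raw_last_batch)
print("displayed at epoch 2, step 1: %.4f" % start_of_epoch2)
```

In this simulation the underlying per-batch loss barely changes across the epoch boundary, yet the displayed value drops by more than an order of magnitude, which resembles the jump in the question's log and per-step plot.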

https://en.xdnf.cn/q/72218.html
