When is it appropriate to use sample_weights in keras?

2024/9/20 14:57:04

From this question, I learnt that class_weight in Keras applies a weighted loss during training, and that sample_weight does something similar per sample, for cases where I don't have equal confidence in all the training samples.

So my questions would be,

  1. Is the loss during validation weighted by the class_weight, or is it only weighted during training?
  2. My dataset has 2 classes, and the class distribution is not seriously imbalanced. The ratio is approx. 1.7 : 1. Is it necessary to use class_weight to balance the loss, or even to use oversampling? Is it OK to leave the slightly imbalanced data as it is and treat it like a usual dataset?
  3. Can I simply consider sample_weight as the weights I give to each training sample? My training samples can be treated with equal confidence, so I probably don't need to use this.
Answer
  1. The Keras documentation says:

class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

So class_weight only affects the loss during training. I was myself interested in understanding how class and sample weights are handled during testing and training. Looking at the Keras GitHub repo and the code for metrics and losses, it does not seem that either the loss or the metrics are affected by them. The printed values are quite hard to trace through the training code, such as model.fit() and its corresponding TensorFlow backend training functions. So I decided to write some test code to check the possible scenarios; see the code below. The conclusion is that both class_weight and sample_weight only affect the training loss, with no effect on any metrics or on the validation loss. This is a little surprising, as val_sample_weights (which you can specify) seem to do nothing(??).

  2. This type of question always depends on your problem: how skewed the data is, and what you are trying to optimize the model for. If you are optimizing for accuracy, then as long as the training data is as skewed as the data the model will see in production, the best result will be achieved by training without any over/under-sampling or class weights. If, on the other hand, one class is more important (or more expensive) than another, then you should weight the data. An example is fraud prevention, where a fraud case normally costs much more than the income from a non-fraud case. I would suggest you try unweighted classes, weighted classes and some under/over-sampling, and check which gives the best validation results. Use a validation function (or write your own) that best compares the different models (for example, by weighting true positives, false positives, true negatives and false negatives differently depending on their cost). A relatively new loss function that has shown great results in Kaggle competitions on skewed data is focal loss. Focal loss reduces the need for over/under-sampling. Unfortunately, focal loss is not a built-in function in Keras (yet), but it can be programmed manually.

  3. Yes, I think you are correct. I normally use sample_weight for two reasons: 1) the training data has some kind of measurement uncertainty which, if known, can be used to weight accurate data more than inaccurate measurements; or 2) we can weight newer data more than old, forcing the model to adapt to new behavior more quickly without ignoring valuable old data.
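
Since focal loss can be programmed by hand, here is a minimal sketch of the binary version in plain Python, using the common formulation FL = -alpha_t * (1 - p_t)^gamma * log(p_t). The function names and the defaults gamma=2.0 and alpha=0.25 are assumptions chosen for illustration; they are not from the answer above, and a loss for actual Keras training would have to be written with tensor operations instead of scalar math.

```python
import math


def binary_cross_entropy(y_true, p, eps=1e-7):
    """Plain cross-entropy for one sample; p is the predicted P(class 1)."""
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
    return -math.log(p) if y_true == 1 else -math.log(1.0 - p)


def binary_focal_loss(y_true, p, gamma=2.0, alpha=0.25, eps=1e-7):
    """Focal loss for one sample (hypothetical helper, illustrative defaults).

    p_t is the probability assigned to the true class. The factor
    (1 - p_t) ** gamma shrinks the loss of well-classified samples,
    and alpha_t balances the two classes.
    """
    p = min(max(p, eps), 1.0 - eps)
    p_t = p if y_true == 1 else 1.0 - p
    alpha_t = alpha if y_true == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


if __name__ == "__main__":
    # Compare an easy positive (p = 0.9) with a hard positive (p = 0.1).
    for p in (0.9, 0.1):
        print(p, binary_cross_entropy(1, p), binary_focal_loss(1, p))
```

With gamma=0 and alpha=0.5 this reduces to half the ordinary cross-entropy, which is a quick sanity check. Larger gamma values push the loss towards the hard, misclassified samples, which is why it can reduce the need for over/under-sampling.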

Here is the code for comparing results with and without class_weight and sample_weight, while keeping the model and everything else static:

import tensorflow as tf
import numpy as np

data_size = 100
input_size = 3
classes = 3

x_train = np.random.rand(data_size, input_size)
y_train = np.random.randint(0, classes, data_size)
# sample_weight_train = np.random.rand(data_size)
x_val = np.random.rand(data_size, input_size)
y_val = np.random.randint(0, classes, data_size)
# sample_weight_val = np.random.rand(data_size)

# shape must be a tuple, hence the trailing comma
inputs = tf.keras.layers.Input(shape=(input_size,))
pred = tf.keras.layers.Dense(classes, activation='softmax')(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=pred)

loss = tf.keras.losses.sparse_categorical_crossentropy
metrics = tf.keras.metrics.sparse_categorical_accuracy
model.compile(loss=loss, metrics=[metrics], optimizer='adam')

# Make model static, so we can compare it between different scenarios
# Note: Keras reads `trainable` when compile() is called, so compile again
# after this loop if the weights must be truly frozen.
for layer in model.layers:
    layer.trainable = False

# base model no weights (same result as without class_weights)
# model.fit(x=x_train, y=y_train, validation_data=(x_val, y_val))
class_weights = {0: 1., 1: 1., 2: 1.}
model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
# which outputs:
# > loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100

# changing the class weights to zero, to check which loss and metric is affected
class_weights = {0: 0, 1: 0, 2: 0}
model.fit(x=x_train, y=y_train, class_weight=class_weights, validation_data=(x_val, y_val))
# which outputs:
# > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100

# changing the sample weights to zero, to check which loss and metric is affected
sample_weight_train = np.zeros(data_size)
sample_weight_val = np.zeros(data_size)
model.fit(x=x_train, y=y_train, sample_weight=sample_weight_train, validation_data=(x_val, y_val, sample_weight_val))
# which outputs:
# > loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100

There are some small deviations between using weights and not using them (even when all weights are one), possibly because fit uses different backend functions for weighted and unweighted data, or because of rounding error.
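
To make the zero-loss results above more concrete, here is a rough sketch in plain Python of how class weights and sample weights can scale a per-sample loss before it is averaged. Both helper names are hypothetical. The inverse-frequency formula for the class weights and the choice to divide by the number of samples (rather than by the sum of the weights) are assumptions of this sketch, and the exact reduction Keras applies may differ between versions.

```python
def class_weights_from_counts(counts):
    """Inverse-frequency weights: n_samples / (n_classes * class_count)."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {label: total / (n_classes * n) for label, n in counts.items()}


def weighted_mean_loss(losses, labels, class_weight=None, sample_weight=None):
    """Scale each per-sample loss by its weight(s), then average.

    A class weight is looked up from the sample's label; a sample weight
    is looked up from the sample's position. Both default to 1.
    """
    total = 0.0
    for i, (loss, label) in enumerate(zip(losses, labels)):
        w = 1.0
        if class_weight is not None:
            w *= class_weight[label]
        if sample_weight is not None:
            w *= sample_weight[i]
        total += w * loss
    return total / len(losses)


if __name__ == "__main__":
    # A 1.7 : 1 class ratio, as in the question (counts are made up).
    print(class_weights_from_counts({0: 170, 1: 100}))

    # Made-up per-sample losses for three samples of classes 0, 1, 2.
    losses, labels = [1.2, 0.9, 1.1], [0, 1, 2]
    print(weighted_mean_loss(losses, labels))
    print(weighted_mean_loss(losses, labels, class_weight={0: 0, 1: 0, 2: 0}))
    print(weighted_mean_loss(losses, labels, sample_weight=[0.0, 0.0, 0.0]))
```

In this sketch a class weight is just a sample weight chosen by label, so setting either kind to zero makes every term of the sum vanish. That matches the reported training loss of 0.0000e+00, while metrics such as accuracy are computed from the predictions and are untouched.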
