Access deprecated attribute validation_data in tf.keras.callbacks.Callback

2024/10/8 6:22:40

I decided to switch from keras to tf.keras (as recommended here), so I installed tf.__version__=2.0.0 and tf.keras.__version__=2.2.4-tf. In an older version of my code (using an older TensorFlow version, tf.__version__=1.x.x) I used a callback to compute custom metrics on the entire validation data at the end of each epoch. The idea was taken from here. However, it seems that the validation_data attribute is deprecated, so the following code no longer works:

import numpy as np
from tensorflow.keras.callbacks import Callback

class ValMetrics(Callback):

    def on_train_begin(self, logs=None):
        self.val_all_mse = []

    def on_epoch_end(self, epoch, logs=None):
        # Fails in TF 2.x: tf.keras no longer fills in self.validation_data
        val_predict = np.asarray(self.model.predict(self.validation_data[0]))
        val_targ = self.validation_data[1]
        val_epoch_mse = mse_score(val_targ, val_predict)  # helper defined elsewhere (not shown)
        self.val_all_mse.append(val_epoch_mse)
        # Add custom metrics to the logs, so that we can use them with
        # the EarlyStopping and CSVLogger callbacks
        if logs is not None:
            logs["val_epoch_mse"] = val_epoch_mse
        print(f"\nEpoch: {epoch + 1}")
        print("-----------------")
        print("val_mse:     {:+.6f}".format(val_epoch_mse))
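The callback calls a helper mse_score that the question does not define. Assuming it is plain mean squared error, a minimal NumPy sketch could look like this (the implementation is an assumption, not the asker's code). It flattens both arrays first, because model.predict typically returns shape (n, 1), and subtracting that from targets of shape (n,) would silently broadcast to an (n, n) matrix:

```python
import numpy as np


def mse_score(y_true, y_pred):
    """Mean squared error; a hypothetical stand-in for the question's helper."""
    # Flatten so that (n, 1) predictions and (n,) targets line up element-wise
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"size mismatch: {y_true.size} targets vs {y_pred.size} predictions"
        )
    return float(np.mean((y_true - y_pred) ** 2))


print(mse_score([1.0, 2.0, 3.0], [[1.0], [2.0], [5.0]]))  # prints 1.3333333333333333
```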

My current workaround is the following: I simply pass validation_data as an argument to the ValMetrics constructor:

class ValMetrics(Callback):
    def __init__(self, validation_data):
        # Run Callback.__init__ itself; super(Callback, self) would skip it
        super().__init__()
        self.X_val, self.y_val = validation_data
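A complete, runnable version of this workaround might look like the sketch below. It assumes TensorFlow 2.x and a validation set held in memory as NumPy arrays; it computes the MSE inline with NumPy instead of the mse_score helper, and the toy data and one-layer model are purely illustrative.

```python
import numpy as np
import tensorflow as tf


class ValMetrics(tf.keras.callbacks.Callback):
    """Computes MSE on the full validation set at the end of every epoch."""

    def __init__(self, validation_data):
        super().__init__()  # initialise the base Callback properly
        self.x_val, self.y_val = validation_data
        self.val_epoch_mse = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        val_predict = np.asarray(self.model.predict(self.x_val, verbose=0))
        # Flatten both arrays so (n, 1) predictions match (n,) or (n, 1) targets
        mse = float(np.mean((val_predict.ravel() - np.ravel(self.y_val)) ** 2))
        self.val_epoch_mse.append(mse)
        # Writing into logs lets callbacks listed after this one
        # (e.g. EarlyStopping or CSVLogger) see the value
        logs["val_epoch_mse"] = mse


# Usage with toy data (illustrative only)
rng = np.random.default_rng(0)
x_train, y_train = rng.normal(size=(64, 3)), rng.normal(size=(64, 1))
x_val, y_val = rng.normal(size=(16, 3)), rng.normal(size=(16, 1))

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

val_metrics = ValMetrics((x_val, y_val))
model.fit(x_train, y_train, epochs=2, verbose=0, callbacks=[val_metrics])
print(val_metrics.val_epoch_mse)
```

Because the value is written into logs, a callback placed after val_metrics in the callbacks list, such as EarlyStopping(monitor="val_epoch_mse"), should be able to monitor it.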

I still have some questions: Is the validation_data attribute really deprecated, or can it be found elsewhere? Is there a better way to access the validation data at the end of each epoch than the workaround above?

Thanks a lot!

Answer

You are right: according to the TensorFlow Callbacks documentation, the validation_data attribute is deprecated.

The issue you are facing has been raised on GitHub; the related issues are Issue1, Issue2 and Issue3.

None of these GitHub issues has been resolved. Your workaround of passing the validation data as an argument to the custom callback is a good one: it is the approach suggested in this GitHub comment, and many people have found it useful.

For the benefit of the Stack Overflow community, here is the code of that workaround, even though it is also available on GitHub:

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.keras.callbacks import Callback

class Metrics(Callback):
    def __init__(self, val_data):
        super().__init__()
        # val_data: a batched iterator supporting len() and next(), e.g. the
        # iterator returned by ImageDataGenerator.flow_from_directory(...)
        self.validation_data = val_data

    def on_train_begin(self, logs=None):
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs=None):
        batches = len(self.validation_data)
        val_pred, val_true = [], []
        for _ in range(batches):
            x_val, y_val = next(self.validation_data)
            # Round the sigmoid outputs to hard 0/1 labels (binary task)
            val_pred.append(np.asarray(self.model.predict(x_val)).round().ravel())
            val_true.append(np.asarray(y_val).ravel())
        # Concatenating per-batch arrays (instead of pre-allocating
        # batches * batch_size rows) also handles a smaller final batch
        val_pred = np.concatenate(val_pred)
        val_true = np.concatenate(val_true)
        _val_f1 = f1_score(val_true, val_pred)
        _val_precision = precision_score(val_true, val_pred)
        _val_recall = recall_score(val_true, val_pred)
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
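When the number of validation samples is not a multiple of the batch size, the last batch from a generator is smaller than the others, so pre-allocating batches * batch_size rows leaves trailing zeros that would be scored as real samples. A framework-free sketch of collecting predictions batch by batch and concatenating them is shown below; collect_predictions and predict_fn are hypothetical names, predict_fn stands in for self.model.predict, and one output value per sample is assumed.

```python
import itertools

import numpy as np


def collect_predictions(batches, num_batches, predict_fn):
    """Gather (y_true, y_pred) over num_batches (x, y) batches.

    batches may be a list, a finite generator or an endless Keras-style
    iterator; itertools.islice stops after num_batches either way.
    predict_fn is a hypothetical stand-in for self.model.predict and is
    assumed to return one value per sample.
    """
    y_true, y_pred = [], []
    for x_batch, y_batch in itertools.islice(batches, num_batches):
        y_pred.append(np.asarray(predict_fn(x_batch)).ravel())
        y_true.append(np.asarray(y_batch).ravel())
    # Concatenation sizes the result from the batches actually seen,
    # so a short final batch never leaves padding rows behind
    return np.concatenate(y_true), np.concatenate(y_pred)


# 7 samples in batches of 3 -> batch sizes 3, 3, 1
x = np.arange(7, dtype=float).reshape(7, 1)
y = (x.ravel() >= 3).astype(float)
batches = [(x[i:i + 3], y[i:i + 3]) for i in range(0, 7, 3)]

y_true, y_pred = collect_predictions(batches, len(batches), lambda xb: xb >= 3)
print(y_true.shape, y_pred.shape)  # (7,) (7,)
```

Passing the batch count explicitly matters for endless iterators such as those produced by ImageDataGenerator, which never raise StopIteration on their own.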

I will keep following the GitHub issues mentioned above and will update this answer accordingly.

Hope this helps. Happy Learning!

https://en.xdnf.cn/q/70153.html
