Reset all weights of Keras model

2024/10/7 4:34:31

I would like to be able to reset the weights of my entire Keras model so that I do not have to compile it again. Compiling the model is currently the main bottleneck of my code. Here is an example of what I mean:

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(10)
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

data = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = data.load_data()

model.fit(x=x_train, y=y_train, epochs=10)

# Reset all weights of model here
# model.reset_all_weights() <----- something like that

model.fit(x=x_train, y=y_train, epochs=10)
Answer

I wrote a function that reinitializes the weights of a model in TensorFlow 2, without recompiling it.

def reinitialize(model):
    for l in model.layers:
        # Re-draw each weight tensor from the layer's own initializer
        if hasattr(l, "kernel_initializer"):
            l.kernel.assign(l.kernel_initializer(l.kernel.shape))
        # Layers built with use_bias=False still have bias_initializer, but bias is None
        if hasattr(l, "bias_initializer") and l.bias is not None:
            l.bias.assign(l.bias_initializer(l.bias.shape))
        if hasattr(l, "recurrent_initializer"):
            l.recurrent_kernel.assign(l.recurrent_initializer(l.recurrent_kernel.shape))
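The function above only walks the top level of `model.layers`, so the layers of a sub-model nested inside the model are never reset. Below is a minimal sketch of a recursive variant, assuming `tf.keras` on TensorFlow 2 and that nested Sequential/functional sub-models expose their own `.layers` list. `reinitialize_nested` is a hypothetical name, not a Keras API. It also skips any weight attribute that is missing or `None`. RNN layers that keep their weights on `layer.cell` (as some Keras versions do) are simply skipped, not reset.

```python
import numpy as np
import tensorflow as tf


def reinitialize_nested(model_or_layer):
    """Hypothetical helper: re-run each layer's initializers, recursing into sub-models."""
    for layer in model_or_layer.layers:
        if hasattr(layer, "layers"):
            # A nested Sequential/functional model: reset its layers instead
            reinitialize_nested(layer)
            continue
        # Only touch weights that exist on this layer and have an initializer
        pairs = [
            ("kernel", "kernel_initializer"),
            ("bias", "bias_initializer"),
            ("recurrent_kernel", "recurrent_initializer"),
        ]
        for weight_name, init_name in pairs:
            weight = getattr(layer, weight_name, None)
            init = getattr(layer, init_name, None)
            if weight is not None and init is not None:
                weight.assign(init(weight.shape))


# Usage: an outer model that contains an inner Sequential model as a layer
inner = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(4)])
outer = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    inner,
    tf.keras.layers.Dense(2, use_bias=False),
])

# Overwrite every weight with zeros, then reset them from the initializers
outer.set_weights([np.zeros_like(w) for w in outer.get_weights()])
reinitialize_nested(outer)
```

With the default `glorot_uniform` kernel initializer and `zeros` bias initializer, the kernels come back non-zero and the biases stay at zero.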

It took me much longer than it should have to come up with this, and I tried many things that failed in my specific use case. IMO this should be a standard TF feature.
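A commonly used alternative, not part of the answer above, is to snapshot the weights once with `model.get_weights()` right after the model is built, and to restore them with `model.set_weights(...)` whenever a reset is needed; neither call requires recompiling. The sketch below assumes `tf.keras` on TensorFlow 2 and substitutes small random arrays for MNIST so it runs offline. One difference: this restores exactly the same starting weights every time, whereas the answer's function re-runs each layer's initializer. In both approaches only the model weights are touched, so any optimizer state (for example Adam's moment estimates) is carried over into the next `fit`.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Snapshot the freshly initialized weights once, before any training
initial_weights = model.get_weights()

# Small random stand-in for MNIST so the sketch runs offline (assumption)
rng = np.random.default_rng(0)
x = rng.random((64, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=64)

model.fit(x, y, epochs=1, verbose=0)
trained_weights = model.get_weights()

# Reset: copy the snapshot back into the existing variables, no compile() call
model.set_weights(initial_weights)
reset_weights = model.get_weights()

# Training continues with the same compiled model
model.fit(x, y, epochs=1, verbose=0)
```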

https://en.xdnf.cn/q/70286.html
