When restoring from a checkpoint, how can I change the data type of the parameters?

2024/10/7 16:21:33

I have a pre-trained TensorFlow checkpoint in which all parameters are of the float32 data type.

How can I load the checkpoint parameters as float16? Alternatively, is there a way to change the data types stored in a checkpoint?

The following is my code snippet, which tries to load the float32 checkpoint into a float16 graph and fails with a type mismatch error.

import tensorflow as tf

A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float32_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dense))
    save_path = saver.save(sess, "tmp.ckpt")

tf.reset_default_graph()
A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
dense = tf.layers.dense(inputs=A, units=3)
varis = tf.trainable_variables(scope=None)
print(varis[1])  # <tf.Variable 'dense/kernel:0' shape=(3, 3) dtype=float16_ref>
assign = dict([(vari.name, vari) for vari in varis])
saver = tf.train.Saver(assign)

with tf.Session() as sess:
    saver.restore(sess, "tmp.ckpt")
    print(sess.run(dense))

# errors:
# tensor_name = dense/bias:0; expected dtype half does not equal original dtype float
# tensor_name = dense/kernel:0; expected dtype half does not equal original dtype float
# tensor_name = foo:0; expected dtype half does not equal original dtype float
Answer

Looking a bit into how savers work, it seems you can redefine their construction through a builder object. For example, you could have a builder that loads the values as tf.float32 and then casts them to the actual type of each variable:

import tensorflow as tf
from tensorflow.python.training.saver import BaseSaverBuilder


class CastFromFloat32SaverBuilder(BaseSaverBuilder):
    # Based on tensorflow.python.training.saver.BulkSaverBuilder.bulk_restore
    def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                     restore_sequentially):
        from tensorflow.python.ops import io_ops
        restore_specs = []
        for saveable in saveables:
            for spec in saveable.specs:
                restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
        names, slices, dtypes = zip(*restore_specs)
        restore_dtypes = [tf.float32 for _ in dtypes]
        with tf.device("cpu:0"):
            restored = io_ops.restore_v2(
                filename_tensor, names, slices, restore_dtypes)
        return [tf.cast(r, dt) for r, dt in zip(restored, dtypes)]

Note that this assumes every restored variable is stored as tf.float32. You can adapt the builder to your use case if necessary, for example by passing the source type or types to the constructor. With this, you just need to pass the above builder to the second saver to get your example to work:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float32)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign)
    sess.run(tf.global_variables_initializer())
    print('Value to save:')
    print(sess.run(dense))
    save_path = saver.save(sess, "ckpt/tmp.ckpt")

with tf.Graph().as_default(), tf.Session() as sess:
    A = tf.get_variable(name='foo', shape=[3, 3], dtype=tf.float16)
    dense = tf.layers.dense(inputs=A, units=3)
    varis = tf.trainable_variables(scope=None)
    assign = {vari.name: vari for vari in varis}
    saver = tf.train.Saver(assign, builder=CastFromFloat32SaverBuilder())
    saver.restore(sess, "ckpt/tmp.ckpt")
    print('Restored value:')
    print(sess.run(dense))

Output:

Value to save:
[[ 0.50589913  0.33701038 -0.11597633]
 [ 0.27372625  0.27724823  0.49825498]
 [ 1.0897961  -0.29577428 -0.9173869 ]]
Restored value:
[[ 0.506    0.337   -0.11597]
 [ 0.2737   0.2773   0.4983 ]
 [ 1.09    -0.296   -0.9175 ]]
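
The small differences between the saved and restored values are what you would expect from float16, which keeps only a 10-bit mantissa (roughly three significant decimal digits). As a rough, TensorFlow-independent illustration, the sketch below uses Python's standard struct module (format code 'e' is IEEE 754 half precision) to round the first row of the saved output to float16. The helper name to_float16 is made up for this example. Note that it only rounds the final numbers, whereas the restored graph also does its arithmetic in float16, so the last digit need not match the restored output exactly.

```python
import struct


def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Hypothetical helper for illustration only; it is not part of TensorFlow.
    """
    return struct.unpack('e', struct.pack('e', x))[0]


# First row of the float32 output printed under "Value to save:".
saved_row = [0.50589913, 0.33701038, -0.11597633]
rounded_row = [to_float16(v) for v in saved_row]

for original, rounded in zip(saved_row, rounded_row):
    # The absolute error stays well below 1e-3 for values of this magnitude.
    print(original, '->', rounded, 'abs error:', abs(original - rounded))
```

If errors of this size are acceptable for your model, casting on restore as in the answer should be enough; otherwise you may want to keep the variables in float32.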
