TensorFlow model prediction is slow

2024/10/13 11:28:26

I have a TensorFlow model with a single Dense layer:

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.build(input_shape=(None, None, 25))

I construct a single input vector in float32:

np_vec = np.array(np.random.randn(1, 1, 25), dtype=np.float32)
vec = tf.cast(tf.convert_to_tensor(np_vec), dtype=tf.float32)

I want to feed it to my model for prediction, but inference is very slow: calling predict or __call__ takes far longer than doing the same operation in NumPy.

  1. Call %timeit model.predict(vec):

    10 loops, best of 3: 21.9 ms per loop

  2. Call the model directly with %timeit model(vec, training=False):

    1000 loops, best of 3: 806 µs per loop

  3. Perform the multiplication operation myself:

    weights = np.array(model.layers[0].get_weights()[0])
    %timeit np_vec @ weights

    1000000 loops, best of 3: 1.27 µs per loop

  4. Perform the multiplication myself using torch

    100000 loops, best of 3: 2.57 µs per loop
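
For anyone reproducing the NumPy baseline outside a notebook, here is a minimal sketch using the standard-library timeit module instead of %timeit. The kernel and bias are random stand-ins (an assumption, not the values from the model above), and dense_forward is a hypothetical helper name. Unlike item 3, it also adds the bias term that a Dense layer applies.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for model.layers[0].get_weights(): kernel of shape (25, 2), bias of shape (2,).
weights = rng.standard_normal((25, 2)).astype(np.float32)
bias = np.zeros(2, dtype=np.float32)

# Same input shape as in the question: (batch, timesteps, features).
np_vec = rng.standard_normal((1, 1, 25)).astype(np.float32)


def dense_forward(x, kernel, b):
    """NumPy equivalent of a Dense layer without an activation."""
    return x @ kernel + b


out = dense_forward(np_vec, weights, bias)
print(out.shape)

# Average the cost over many calls, as %timeit does.
n = 100_000
total = timeit.timeit(lambda: dense_forward(np_vec, weights, bias), number=n)
print(f"{total / n * 1e6:.2f} µs per call")
```

The absolute numbers depend on the machine, so they are only useful for comparison against the TensorFlow timings measured on the same hardware.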

Google Colab: https://colab.research.google.com/drive/1RCnTM24RUI4VkykVtdRtRdUVEkAHdu4A?usp=sharing

How can I make my TensorFlow model faster in inference time? Especially because I don't only have a Dense layer, but I also use an LSTM and I don't want to reimplement that in NumPy.

Answer

The slowdown comes from how the Keras LSTM layer is implemented. By default the layer is created with unroll=False, so the recurrence runs as a symbolic loop, which adds overhead on every call. Try passing unroll=True to the LSTM:

tf.keras.layers.LSTM(64, return_sequences=True, stateful=True, unroll=True)

This may give up to a 2x speedup (tested on my machine with %timeit model(vec, training=False)). However, unroll=True can use more memory for longer sequences; the Keras documentation notes that unrolling is only suitable for short sequences. See the Keras LSTM documentation for details.
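
To check the effect on your own machine, a rough comparison could look like the sketch below. The architecture (one LSTM with 64 units followed by Dense(2)) and the build_model helper are assumptions for illustration, and stateful=True is left out to keep the example short. The measured ratio will vary with hardware and TensorFlow version.

```python
import timeit

import numpy as np
import tensorflow as tf


def build_model(unroll):
    # Hypothetical architecture: one LSTM followed by the Dense(2) head.
    return tf.keras.Sequential([
        tf.keras.layers.LSTM(64, return_sequences=True, unroll=unroll),
        tf.keras.layers.Dense(2),
    ])


# Same input shape as in the question: (batch, timesteps, features).
vec = tf.constant(np.random.randn(1, 1, 25).astype(np.float32))

results = {}
for unroll in (False, True):
    model = build_model(unroll)
    out = model(vec, training=False)  # first call creates the weights
    n = 200
    total = timeit.timeit(lambda: model(vec, training=False), number=n)
    results[unroll] = (tuple(out.shape), total / n)
    print(f"unroll={unroll}: {total / n * 1e6:.0f} µs per call")
```

Both variants are called on an input with a concrete number of timesteps, since unrolling needs the sequence length to be known.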

https://en.xdnf.cn/q/69536.html
