TensorFlow performance bottleneck on IteratorGetNext

2024/10/7 16:18:37

While fiddling around with TensorFlow, I noticed that a relatively simple task (batching some of our 3D accelerometer data and taking the sum of each epoch) performed poorly. Here's the essence of what I had running, once I got the (incredibly nifty!) Timeline functionality up:

import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline

# Some dummy functions to compute "features" from the data
def compute_features( data ):
    feature_functions = [
        lambda x: test_sum( x, axis = 0 ),
        lambda x: test_sum( x, axis = 1 ),
        lambda x: test_sum( x, axis = 2 ),
    ]
    return tf.convert_to_tensor( [ f( data ) for f in feature_functions ] )

def test_sum( data, axis = 0 ):
    t, v = data
    return tf.reduce_sum( v[:, axis] )

# Setup for using Timeline
sess = tf.Session()
run_options = tf.RunOptions( trace_level = tf.RunOptions.FULL_TRACE )
run_metadata = tf.RunMetadata()

# Some magic numbers for our dataset
test_sampling_rate = 5000.0
segment_size = int( 60 * test_sampling_rate )

# Load the dataset
with np.load( 'data.npz' ) as data:
    t_raw = data['t']
    v_raw = data['v']

# Build the iterator
full_dataset = tf.data.Dataset.from_tensor_slices( (t_raw, v_raw) ).batch( segment_size )
dataset_iterator = full_dataset.make_initializable_iterator()
next_datum = dataset_iterator.get_next()

sess.run( dataset_iterator.initializer )
i = 0
while True:
    try:
        print( sess.run( compute_features( next_datum ),
                         options = run_options,
                         run_metadata = run_metadata ) )
        # Write Timeline data to a file for analysis later
        tl = timeline.Timeline( run_metadata.step_stats )
        ctf = tl.generate_chrome_trace_format()
        with open( 'timeline_{0}.json'.format( i ), 'w' ) as f:
            f.write( ctf )
        i += 1
    except tf.errors.OutOfRangeError:
        break
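To sanity-check what the graph above should print, here is a minimal NumPy sketch of the same computation. It assumes `v_raw` has shape `(N, 3)` (one column per accelerometer axis), and, like `Dataset.batch` without dropping the remainder, it keeps a shorter final segment. The helper name `reference_features` is hypothetical, and the small array stands in for `data.npz`.

```python
import numpy as np


def reference_features(v_raw, segment_size):
    """Per-segment sums of each axis of v_raw, mirroring compute_features().

    Hypothetical helper: assumes v_raw has shape (N, 3). The last segment may
    be shorter than segment_size, as with Dataset.batch() keeping the remainder.
    """
    features = []
    for start in range(0, len(v_raw), segment_size):
        segment = v_raw[start:start + segment_size]
        # One sum per axis, matching test_sum(x, axis=0), (axis=1), (axis=2)
        features.append(segment.sum(axis=0))
    return np.array(features)


if __name__ == "__main__":
    # Stand-in data: 7 samples of 3-axis readings, batched in segments of 3
    v = np.arange(21, dtype=np.float64).reshape(7, 3)
    print(reference_features(v, segment_size=3))
```

With the real data you would call `reference_features(v_raw, segment_size)` and compare each row against the corresponding `sess.run` output.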

Pulling this up in Chrome, I observed that in each iteration, IteratorGetNext was eating up the vast majority of the time:

Screenshot of Chrome displaying the timeline for one iteration

As you can see, the "main" part of the computation is squeezed into the tiny blips on the right-hand side, while the vast majority of the time for this cycle is spent in IteratorGetNext.
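To put numbers on the screenshot instead of eyeballing it, the `timeline_{i}.json` files written by the loop can be summarized with the standard library. This is a sketch that assumes the usual Chrome trace layout (a top-level `traceEvents` list whose complete events have `"ph": "X"` and a `dur` in microseconds) and that the op type sits in each event's `args["op"]`, falling back to the event `name` otherwise. The helper names `op_time_share` and `load_trace` are hypothetical. Shares are computed over summed op time across all devices and threads, so they approximate, rather than equal, the wall-clock split.

```python
import json
from collections import defaultdict


def op_time_share(trace):
    """Return {op: fraction of summed op time} for one Chrome-trace dict.

    Assumes complete events ("ph": "X") carry "dur" in microseconds; the op
    label comes from args["op"] if present, else from the event "name".
    """
    totals = defaultdict(float)
    for event in trace.get("traceEvents", []):
        if event.get("ph") != "X" or "dur" not in event:
            continue  # skip metadata, flow, counter, and object events
        args = event.get("args") or {}
        op = args.get("op", event.get("name", "?"))
        totals[op] += event["dur"]
    grand_total = sum(totals.values()) or 1.0  # avoid dividing by zero
    return {op: dur / grand_total for op, dur in totals.items()}


def load_trace(path):
    with open(path) as f:
        return json.load(f)


if __name__ == "__main__":
    # With real files: op_time_share(load_trace("timeline_0.json"))
    sample = {"traceEvents": [
        {"ph": "X", "name": "IteratorGetNext", "dur": 900,
         "args": {"op": "IteratorGetNext"}},
        {"ph": "X", "name": "Sum", "dur": 100, "args": {"op": "Sum"}},
        {"ph": "M", "name": "process_name", "args": {"name": "/cpu:0"}},
    ]}
    for op, share in sorted(op_time_share(sample).items(), key=lambda kv: -kv[1]):
        print(f"{op:20s} {share:6.1%}")
```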

Am I missing something obvious in the way I've constructed my graph that would cause performance to degrade so egregiously on this step? I'm a bit stumped as to why this setup performs so poorly.

Answer

If IteratorGetNext shows up as a large event in the timeline, then your model is bottlenecked on input processing. In this case the pipeline is quite simple, but it is bottlenecked on copying 300,000 elements into each batch. You can move this copy off the critical path by adding a Dataset.prefetch(1) transformation to the dataset definition, which lets the next batch be prepared in the background while the current one is being processed:

full_dataset = (tf.data.Dataset.from_tensor_slices((t_raw, v_raw))
                .batch(segment_size)
                .prefetch(1))
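The benefit of `prefetch(1)` can be illustrated outside TensorFlow. The sketch below is a plain-Python analogy, not how tf.data is implemented: a background thread fills a bounded queue with the next batch while the consumer works on the current one, so in steady state each step costs roughly max(produce, consume) instead of their sum. The names `prefetch`, `slow_batches`, `run`, and the sleep durations are hypothetical stand-ins for batching and for the `sess.run` work.

```python
import queue
import threading
import time

_END = object()  # sentinel marking the end of the stream


def prefetch(iterable, buffer_size=1):
    """Yield items from iterable, producing up to buffer_size ahead in a thread.

    Analogy only: error propagation from the producer is omitted.
    """
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        for item in iterable:
            buf.put(item)  # blocks while the buffer is full
        buf.put(_END)

    # Daemon thread so an abandoned generator does not keep the process alive
    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _END:
            return
        yield item


def slow_batches(n, produce_s):
    """Stand-in for batching: each batch takes produce_s seconds to build."""
    for i in range(n):
        time.sleep(produce_s)
        yield i


def run(batches, consume_s):
    """Consume all batches, spending consume_s seconds on each; return elapsed."""
    start = time.perf_counter()
    for _ in batches:
        time.sleep(consume_s)  # stand-in for the per-batch computation
    return time.perf_counter() - start


if __name__ == "__main__":
    n, produce_s, consume_s = 5, 0.05, 0.05
    serial = run(slow_batches(n, produce_s), consume_s)
    overlapped = run(prefetch(slow_batches(n, produce_s)), consume_s)
    print(f"serial: {serial:.2f}s  with prefetch: {overlapped:.2f}s")
```

Prefetching only hides the batching cost behind the computation; if batching is much slower than the computation, the pipeline is still bound by batching.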

For more performance suggestions, see the new Input Pipeline Performance Guide on tensorflow.org.

PS. Calling compute_features(next_datum) inside the loop adds a new set of ops to the graph on every iteration, so the graph grows and the loop slows down over time. Rewriting it as follows, so the ops are built once before the loop, will be more efficient:

next_computed_features = compute_features(next_datum)
while True:
    try:
        print(sess.run(next_computed_features, options=run_options,
                       run_metadata=run_metadata))
        # ...
    except tf.errors.OutOfRangeError:
        break
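The graph growth described in the PS can be checked directly by counting operations. This sketch assumes a TensorFlow version that provides `tf.Graph`, `Graph.as_default()`, `Graph.get_operations()` and `Graph.finalize()` (1.x does, and so does 2.x when building an explicit graph); the helper names are hypothetical and the small constant stands in for `next_datum`. It also shows a related safeguard: calling `finalize()` on the graph before the loop (in the question's code, `sess.graph.finalize()`) makes accidental op creation raise a `RuntimeError` instead of silently growing the graph.

```python
import tensorflow as tf


def count_ops_after_building(n_steps, build_inside_loop):
    """Count graph ops after n_steps of a toy loop (graph construction only).

    build_inside_loop=True mimics calling compute_features(next_datum) on each
    iteration; False mimics building the op once before the loop.
    """
    graph = tf.Graph()
    with graph.as_default():
        data = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        features = tf.reduce_sum(data, axis=0)  # built once, before the loop
        for _ in range(n_steps):
            if build_inside_loop:
                features = tf.reduce_sum(data, axis=0)  # adds new ops each step
            # sess.run(features) would go here; only graph size matters for this demo
    return len(graph.get_operations())


def finalized_graph_rejects_new_ops():
    """Return True if adding an op after Graph.finalize() raises RuntimeError."""
    graph = tf.Graph()
    with graph.as_default():
        data = tf.constant([1.0, 2.0])
        tf.reduce_sum(data)
        graph.finalize()  # the graph is read-only from here on
        try:
            tf.reduce_sum(data)  # would otherwise silently grow the graph
        except RuntimeError:
            return True
    return False


if __name__ == "__main__":
    for n in (1, 10):
        print(n, count_ops_after_building(n, False), count_ops_after_building(n, True))
    print("finalize guards the graph:", finalized_graph_rejects_new_ops())
```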
https://en.xdnf.cn/q/70221.html
