Reusing Tensorflow session in multiple threads causes crash

2024/10/12 17:52:32

Background

I have a complex reinforcement learning algorithm that I want to run in multiple threads.

Problem

When I call sess.run in a thread, I get the following error message:

RuntimeError: The Session graph is empty. Add operations to the graph before calling run().

Code reproducing the error:

import tensorflow as tf
import threading


def thread_function(sess, i):
    inn = [1.3, 4.5]
    A = tf.placeholder(dtype=float, shape=(None), name="input")
    P = tf.Print(A, [A])
    Q = tf.add(A, P)
    sess.run(Q, feed_dict={A: inn})


def main(sess):
    thread_list = []
    for i in range(0, 4):
        t = threading.Thread(target=thread_function, args=(sess, i))
        thread_list.append(t)
        t.start()
    for t in thread_list:
        t.join()


if __name__ == '__main__':
    sess = tf.Session()
    main(sess)

If I run the same code outside a thread, it works properly.

Can someone explain how to use TensorFlow sessions properly with Python threads?

Answer

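Not only the session but also the graph has a per-thread default. The default graph is a property of the current thread, so a newly started thread gets its own, separate default graph. Although you pass in the session and call run on it, the operations you create inside the thread are added to that thread's default graph, not to sess.graph. The session's graph therefore stays empty, which is exactly what the error message says.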
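To make the mechanism concrete, here is a minimal pure-Python sketch of a per-thread default, built on threading.local. ToyGraph, default_graph, as_default and add_op are hypothetical toy names chosen to mirror the TF 1.x concepts; they are not TensorFlow's API. The broken worker reproduces the empty-graph situation, and the fixed worker mirrors the `with sess.graph.as_default():` idea.

```python
import threading
from contextlib import contextmanager


class ToyGraph:
    """Hypothetical stand-in for a TF 1.x graph; it only records op names."""

    def __init__(self, name):
        self.name = name
        self.ops = []


class _PerThreadDefault(threading.local):
    # threading.local calls __init__ again the first time each new thread
    # touches the object, so every thread starts with its own default graph.
    def __init__(self):
        self.graph = ToyGraph("default-of-" + threading.current_thread().name)


_default = _PerThreadDefault()


def default_graph():
    return _default.graph


@contextmanager
def as_default(graph):
    # Temporarily make `graph` the default for the calling thread only.
    previous = _default.graph
    _default.graph = graph
    try:
        yield graph
    finally:
        _default.graph = previous


def add_op(name):
    # Like creating a TF 1.x op: it goes into the current thread's default graph.
    default_graph().ops.append(name)


# The graph a session created in the main thread would hold on to.
session_graph = default_graph()
seen = {}


def broken_worker():
    seen["worker_graph"] = default_graph()
    add_op("placeholder")
    add_op("add")


def fixed_worker():
    with as_default(session_graph):
        add_op("placeholder")
        add_op("add")


t = threading.Thread(target=broken_worker)
t.start()
t.join()
ops_after_broken = list(session_graph.ops)
print(ops_after_broken)  # [] : the ops went into the worker's own graph

t = threading.Thread(target=fixed_worker)
t.start()
t.join()
print(session_graph.ops)  # ['placeholder', 'add']
```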

You can amend your thread_function like this to make it work:

def thread_function(sess, i):
    # Create the ops in the session's graph instead of this thread's own default graph.
    with sess.graph.as_default():
        inn = [1.3, 4.5]
        A = tf.placeholder(dtype=float, shape=(None), name="input")
        P = tf.Print(A, [A])
        Q = tf.add(A, P)
        sess.run(Q, feed_dict={A: inn})
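A related pattern, sketched here under the assumption of TensorFlow 2.x (or 1.13+) with the tf.compat.v1 API, is to build the graph once in the main thread and let the worker threads only call sess.run. The TF 1.x tf.Graph documentation warns that graph construction is not thread-safe, while calling run on one shared session from several threads is the commonly recommended usage. The doubled-input op and the per-thread inputs are illustrative choices, not part of the original question.

```python
import threading

import tensorflow.compat.v1 as tf  # assumption: TF 2.x (or 1.13+) with the v1 API

tf.disable_eager_execution()  # placeholders and tf.Session need graph mode


def build_graph():
    # Create every op once, in the main thread, inside an explicit graph.
    graph = tf.Graph()
    with graph.as_default():
        A = tf.placeholder(dtype=tf.float32, shape=(None,), name="input")
        Q = tf.add(A, A)
    # From now on, any attempt to add an op (e.g. from a worker) raises an error.
    graph.finalize()
    return graph, A, Q


def thread_function(sess, A, Q, i, results):
    # Workers only run ops that already exist; they never build new ones.
    results[i] = sess.run(Q, feed_dict={A: [1.3 * i, 4.5]})


def main(num_threads=4):
    graph, A, Q = build_graph()
    results = [None] * num_threads
    with tf.Session(graph=graph) as sess:
        threads = [
            threading.Thread(target=thread_function, args=(sess, A, Q, i, results))
            for i in range(num_threads)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
    return results


if __name__ == "__main__":
    for r in main():
        print(r)
```

Finalizing the graph is optional, but it turns an accidental op creation inside a thread into an immediate error instead of a silently growing graph.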

However, I wouldn't expect any significant speedup. Threads in Python don't behave the way they do in some other languages: because of CPython's Global Interpreter Lock (GIL), only certain operations, such as I/O, actually run in parallel. For CPU-heavy operations, threading is not very useful. Multiprocessing can run code truly in parallel, but the processes wouldn't share the same session; each process would need its own.
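A small standard-library illustration of the I/O point: here time.sleep stands in for a blocking I/O call (an assumption made for the demo). CPython releases the GIL while a thread is blocked, so four threads that each wait 0.2 s should finish in about 0.2 s in total rather than 0.8 s.

```python
import threading
import time


def fake_io(delay, results, i):
    # time.sleep stands in for blocking I/O; the GIL is released while waiting,
    # so the other threads keep making progress.
    time.sleep(delay)
    results[i] = i


def run_threads(n=4, delay=0.2):
    results = [None] * n
    threads = [
        threading.Thread(target=fake_io, args=(delay, results, i)) for i in range(n)
    ]
    start = time.perf_counter()
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results, time.perf_counter() - start


results, elapsed = run_threads()
# The waits overlap, so elapsed is close to one delay, not n delays.
print(results, f"{elapsed:.2f}s")
```

Replacing the sleep with a pure-Python CPU loop would not give the same overlap on a standard CPython build, which is the situation the answer warns about.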

https://en.xdnf.cn/q/69543.html
