Transform map to mapPartitions using PySpark

2024/10/14 6:17:39

I am trying to load a TensorFlow model from disk and use it to predict values.

Code

def get_value(row):
    print("**********************************************")
    graph = tf.Graph()
    rowkey = row[0]
    checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300"
    print("Loading model................................")
    with graph.as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
                                      log_device_placement=log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)
            input_x = graph.get_operation_by_name("X_train").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]
            batch_predictions = sess.run(predictions, {input_x: [row[1]], dropout_keep_prob: 1.0})
            print(batch_predictions)
            return (rowkey, batch_predictions)

I have a RDD which consists of a tuple (rowkey, input_vector). I want to use the loaded model to predict the score/class of the input.

Code to call get_value()

result = data_rdd.map(lambda iter: get_value(iter))
result.foreach(print)

The problem is that the model is loaded again for every tuple the map processes, and that takes a lot of time.

I am thinking of loading the model using mapPartitions and then using map to call the get_value function. I have no clue how to convert the code to mapPartitions so that the TensorFlow model is loaded only once per partition, reducing the running time.

Thanks in advance.

Answer

I am not sure I understand your question correctly, but we can optimise your code a bit here.

import cPickle

graph = tf.Graph()
checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300"
with graph.as_default():
    session_conf = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
                                  log_device_placement=log_device_placement)
    sess = tf.Session(config=session_conf)
    s = sess.as_default()
    saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
    saver.restore(sess, checkpoint_file)
    input_x = graph.get_operation_by_name("X_train").outputs[0]
    dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
    predictions = graph.get_operation_by_name("output/predictions").outputs[0]

# Serialize the restored session once on the driver
session_pickle = cPickle.dumps(sess)

def get_value(key, vector, session_pickle):
    # Deserialize the session on the worker and run the prediction
    sess = cPickle.loads(session_pickle)
    rowkey = key
    batch_predictions = sess.run(predictions, {input_x: [vector], dropout_keep_prob: 1.0})
    print(batch_predictions)
    return (rowkey, batch_predictions)

result = data_rdd.map(lambda (key, row): get_value(key=key, vector=row, session_pickle=session_pickle))
result.foreach(print)

So you can serialize your TensorFlow session. I haven't tested the code here, though. Run it and leave a comment.
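For reference, here is a minimal sketch of the mapPartitions version the question asks about. It assumes the same checkpoint path and tensor names (X_train, dropout_keep_prob, output/predictions) as the code above; the model is restored once per partition and then reused for every (rowkey, input_vector) tuple in that partition.

import tensorflow as tf

def predict_partition(rows):
    # Restore the model once for the whole partition
    checkpoint_file = "/home/sahil/Desktop/Relation_Extraction/data/1485336002/checkpoints/model-300"
    graph = tf.Graph()
    with graph.as_default():
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)
        input_x = graph.get_operation_by_name("X_train").outputs[0]
        dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
        predictions = graph.get_operation_by_name("output/predictions").outputs[0]
        # Score every tuple in the partition with the already-loaded model
        for rowkey, vector in rows:
            batch_predictions = sess.run(predictions,
                                         {input_x: [vector], dropout_keep_prob: 1.0})
            yield (rowkey, batch_predictions)

result = data_rdd.mapPartitions(predict_partition)
result.foreach(print)

Because mapPartitions passes an iterator of rows to the function, the expensive restore happens once per partition instead of once per tuple.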
