How to extract the cell state and hidden state from an RNN model in TensorFlow?

2024/10/9 4:21:55

I am new to TensorFlow and have difficulties understanding the RNN module. I am trying to extract hidden/cell states from an LSTM. For my code, I am using the implementation from https://github.com/aymericdamien/TensorFlow-Examples.

# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Define weights
weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))}
biases = {'out': tf.Variable(tf.random_normal([n_classes]))}

def RNN(x, weights, biases):
    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Permuting batch_size and n_steps
    x = tf.transpose(x, [1, 0, 2])
    # Reshaping to (n_steps*batch_size, n_input)
    x = tf.reshape(x, [-1, n_input])
    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.split(0, n_steps, x)

    # Define a lstm cell with tensorflow
    #with tf.variable_scope('RNN'):
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True)

    # Get lstm cell output
    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(outputs[-1], weights['out']) + biases['out'], states

pred, states = RNN(x, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.initialize_all_variables()

Now I want to extract the cell/hidden state for each time step in a prediction. The state is stored in an LSTMStateTuple of the form (c, h), which I found out by running `print states`. However, calling `print states.c.eval()` (which, according to the documentation, should give me the values in the tensor states.c) raises an error that I read as saying my variables are not initialized, even though I call it right after predicting something. The code for this is here:

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='RNN'):
        print v.name
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        print states.c.eval()
        # Calculate batch accuracy
        acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
        step += 1
    print "Optimization Finished!"

and the error message is

InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float
[[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

The states are also not visible in tf.all_variables(); only the trained matrix/bias tensors are (as described in the question "Tensorflow: show or save forget gate values in LSTM"). I don't want to build the whole LSTM from scratch, though: I already have the states in the states variable, I just need to evaluate it.

Answer

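You can fetch the values of the states the same way accuracy is fetched: pass them to sess.run together with the feed_dict. Calling states.c.eval() on its own fails because eval() starts a fresh run of the graph without feeding the placeholder x, which is exactly what the error message reports; it has nothing to do with uninitialized variables.

I guess something like pred_val, states_val, acc = sess.run([pred, states, accuracy], feed_dict={x: batch_x, y: batch_y}) should work. The fetches go in a single list, and the results are bound to new names so that the tensors pred and states are not overwritten inside the training loop. If your TensorFlow version does not accept the state tuple as a fetch, fetch states.c and states.h individually; states.c.eval(feed_dict={x: batch_x}) is another option.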
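Note that the states returned by rnn.rnn are the (c, h) pair after the last time step only, while the outputs list holds h for every step. The question asks for the state at each time step, so here is a rough, library-free sketch of what an unrolled LSTM loop computes and how a per-step record of (c, h) could be collected. Everything in it is an assumption made for illustration: the names lstm_step and unroll, the gate order (input, candidate, forget, output), and the toy weights are hypothetical, and this is not TensorFlow's implementation.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def lstm_step(x_t, c_prev, h_prev, W, b, forget_bias=1.0):
    """One LSTM step on plain Python lists (illustrative, not TensorFlow code).

    W has 4 * n_hidden rows of length n_input + n_hidden; b has 4 * n_hidden
    entries. Assumed gate order: input, candidate, forget, output.
    """
    n = len(h_prev)
    z = x_t + h_prev  # concatenate input and previous hidden state
    pre = [sum(w * v for w, v in zip(row, z)) + bias for row, bias in zip(W, b)]
    i, j, f, o = pre[0:n], pre[n:2 * n], pre[2 * n:3 * n], pre[3 * n:4 * n]
    # New cell state: keep part of the old state, add the gated candidate.
    c = [sigmoid(f[k] + forget_bias) * c_prev[k] + sigmoid(i[k]) * math.tanh(j[k])
         for k in range(n)]
    # New hidden state: gated, squashed view of the cell state.
    h = [sigmoid(o[k]) * math.tanh(c[k]) for k in range(n)]
    return c, h

def unroll(xs, W, b, n_hidden):
    """Run the cell over a sequence and record (c, h) after every step."""
    c = [0.0] * n_hidden
    h = [0.0] * n_hidden
    per_step = []
    for x_t in xs:
        c, h = lstm_step(x_t, c, h, W, b)
        per_step.append((c, h))
    return per_step

# Toy setup: 3 time steps, 2 input features, 2 hidden units (made-up values).
n_input, n_hidden = 2, 2
W = [[0.1 * ((r + col) % 3 - 1) for col in range(n_input + n_hidden)]
     for r in range(4 * n_hidden)]
b = [0.0] * (4 * n_hidden)
xs = [[1.0, 0.5], [0.0, -1.0], [0.3, 0.3]]

per_step = unroll(xs, W, b, n_hidden)
final_c, final_h = per_step[-1]  # corresponds to what `states` holds
for t, (c, h) in enumerate(per_step):
    print(t, c, h)
```

In the TensorFlow graph from the question, the per-step h values are already available by fetching the outputs list (RNN would have to return it as well), whereas per-step c values would require stepping the cell manually in a loop like the one above.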
