Loop over a tensor and apply function to each element

2024/9/20 17:49:44

I want to loop over a tensor that contains a list of ints and apply a function to each element. Inside the function, each element should be replaced by the value it maps to in a Python dict. I tried the easy way with tf.map_fn, which works for a simple add function, as in the following code:

import tensorflow as tf

def trans_1(x):
    return x + 10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]

But the following code throws the exception KeyError: <tf.Tensor 'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32>:

import tensorflow as tf

kv_dict = {1: 11, 2: 12, 3: 13}

def trans_2(x):
    return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

My TensorFlow version is 1.13.1. Thanks in advance.

Answer

There is a simple way to achieve what you are trying to do.

The problem is that the function passed to map_fn receives a tensor as its argument and must return a tensor. Your function trans_2, however, is written for a plain Python int: it uses its argument as a dict key and returns a Python int. Inside map_fn the argument is a symbolic tensor, which is not a key of kv_dict, hence the KeyError.
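To make the KeyError concrete without running TensorFlow, here is a minimal pure-Python sketch. The SymbolicValue class and the lookup_outcome helper are hypothetical stand-ins invented for illustration. They are not TensorFlow APIs, and they only mimic the fact that a graph tensor is an object that names a value rather than being that value.

```python
kv_dict = {1: 11, 2: 12, 3: 13}


class SymbolicValue:
    """Hypothetical stand-in for a graph tensor (not a TensorFlow class).

    It carries only a name; it does not compare equal to any int.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return "<SymbolicValue %r>" % self.name


def trans_2(x):
    # Same body as in the question: x is used directly as a dict key.
    return kv_dict[x]


def lookup_outcome(key):
    """Return the looked-up value, or the string 'KeyError' if the key is absent."""
    try:
        return trans_2(key)
    except KeyError:
        return "KeyError"


print(lookup_outcome(2))                                  # prints 12
print(lookup_outcome(SymbolicValue("map/while/read:0")))  # prints KeyError
```

The dict only has the integers 1, 2 and 3 as keys, so any object that is not equal to one of them, such as the placeholder object here, cannot be found.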

TensorFlow provides a simple way to wrap ordinary Python functions, tf.py_func. You can use it in your case as follows:

import tensorflow as tf

kv_dict = {1: 11, 2: 12, 3: 13}

def trans_2(x):
    return kv_dict[x]

def wrapper(x):
    return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)

a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

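As you can see, I added a wrapper function that takes a tensor and returns a tensor, which is why it can be used in map_fn. The cast is needed because the Python int returned by trans_2 is typically converted to a 64-bit integer (hence tf.int64 as the py_func output type), whereas map_fn by default expects the result to have the same dtype as a, which is int32 here.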

