Is there a way to pass a dictionary in tf.data.Dataset with tf.py_func?

2024/10/13 13:13:49

I'm using tf.data.Dataset for data processing and I want to apply some Python code with tf.py_func.

However, I found that a function wrapped in tf.py_func cannot return a dictionary. Is there a way to do this, or a workaround?

My code looks like this:

def map_func(images, labels):
    """mapping python function"""
    # do something
    # cannot be expressed as a tensor graph
    return {'images': images,
            'labels': labels,
            'new_key': new_value}

def tf_py_func(images, labels):
    return tf.py_func(map_func, [images, labels], [tf.uint8, tf.string], name='blah')

return dataset.map(tf_py_func)

===========================================================================

It's been a while and I forgot I had asked this question. I solved it another way, and the fix was so easy that I felt almost stupid. The problem was:

  1. tf.py_func cannot return a dictionary.
  2. The function passed to dataset.map can return a dictionary.

And the answer is: map twice.

def map_func(images, labels):
    """mapping python function"""
    # do something
    # cannot be expressed as a tensor graph
    return processed_images, processed_labels

def tf_py_func(images, labels):
    return tf.py_func(map_func, [images, labels], [tf.uint8, tf.string], name='blah')

def _to_dict(images, labels):
    return {'images': images, 'labels': labels}

return dataset.map(tf_py_func).map(_to_dict)

Answer

You could pack the dictionary's values into a single string, return that string, and then split it back into a dictionary.

This could look something like this:

return (images + " " + labels + " " + new_value)

and then in your other function:

l = map_func(image, label).split(" ")
d['images'] = l[0]
d[
...
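As a rough sketch of the string round trip described above, here is a plain-Python version. It swaps the space-joined string for JSON, which is a substitution of mine and not the answerer's method, because splitting on spaces breaks as soon as a value contains a space. The helper names `pack` and `unpack` and the sample values are hypothetical, and the sketch assumes the values are JSON-serializable (strings, numbers, lists) rather than raw image arrays.

```python
import json


def pack(images, labels, new_value):
    """Serialize the three values into one string (hypothetical helper)."""
    return json.dumps({"images": images, "labels": labels, "new_key": new_value})


def unpack(packed):
    """Rebuild the dictionary from the packed string (hypothetical helper)."""
    return json.loads(packed)


# The space-joined variant yields too many pieces when a value contains a space.
joined = " ".join(["img 01.png", "cat", "extra"])
parts = joined.split(" ")  # four pieces, although only three values went in

# The JSON variant keeps each value intact.
packed = pack("img 01.png", "cat", "extra")
d = unpack(packed)
```

Decoding the string has to happen in Python code again, so this only helps where the packed string is read outside the tensor graph.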
