Is there a proper way to subclass TensorFlow's Dataset?

2024/10/2 12:17:48

I was looking at different ways to build custom TensorFlow datasets. I am used to PyTorch's datasets, but when I looked at TensorFlow's, I saw this example:

class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)
            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
            args=(num_samples,),
        )

But two questions came up:

  1. It looks like all this does is call the static method tf.data.Dataset.from_generator from __new__ when the object is instantiated. So why not just call it directly? What is the point of subclassing tf.data.Dataset at all? Are any methods from tf.data.Dataset even used?
  2. Is there a way to do it like a data generator, where one fills out an __iter__ method while inheriting from tf.data.Dataset? Something like:

class MyDataLoader(tf.data.Dataset):
    def __init__(self, path, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data = pd.read_csv(path)

    def __iter__(self):
        for datum in self.data.iterrows():
            yield datum

Thank you all very much!

Answer

Question 1

That example just wraps a generator-backed dataset in a class. It inherits from tf.data.Dataset because from_generator() returns a tf.data.Dataset-based object. However, as the example shows, the class itself uses none of tf.data.Dataset's methods. So the answer to question 1 is yes: from_generator() can be called directly, without the class.
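As a minimal sketch of the class-free version, the same generator can be handed to from_generator() from an ordinary function. The names artificial_generator and make_artificial_dataset are made up for this illustration; the sleep calls and the output signature are copied from the example in the question.

```python
import time


def artificial_generator(num_samples):
    # Simulate opening a file.
    time.sleep(0.03)
    for sample_idx in range(num_samples):
        # Simulate reading one record from the file.
        time.sleep(0.015)
        yield (sample_idx,)


def make_artificial_dataset(num_samples=3):
    # Hypothetical helper: builds the dataset without any subclass.
    # TensorFlow is imported here so the generator above stays usable on its own.
    import tensorflow as tf

    return tf.data.Dataset.from_generator(
        artificial_generator,
        output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
        args=(num_samples,),
    )
```

The returned object is a regular tf.data.Dataset, so it can be iterated or chained with methods such as batch() and prefetch() exactly like the one produced by the class in the question.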

Question 2

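Yes, that idea works, although the iteration logic is usually written as a plain generator function and passed to tf.data.Dataset.from_generator() rather than as an __iter__ override on a subclass.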
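For question 2, a generator-style loader could be sketched roughly as follows. This is only an illustration under several assumptions: it swaps pd.read_csv for the standard-library csv module, it assumes the file has one header row and purely numeric columns, and the names csv_rows and make_csv_dataset are hypothetical.

```python
import csv


def csv_rows(path):
    """Yield each data row of a CSV file as a list of floats (header skipped)."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (assumed to exist)
        for row in reader:
            yield [float(value) for value in row]


def make_csv_dataset(path, num_columns):
    # Hypothetical helper: wraps the generator in a tf.data.Dataset.
    # TensorFlow is imported here so csv_rows stays usable on its own.
    import tensorflow as tf

    return tf.data.Dataset.from_generator(
        lambda: csv_rows(path),
        output_signature=tf.TensorSpec(shape=(num_columns,), dtype=tf.float32),
    )
```

The iteration logic lives in an ordinary generator, which plays the role that __iter__ plays in the question's MyDataLoader, and from_generator() takes care of turning it into a dataset.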

Another, similar approach would be to subclass tf.keras.utils.Sequence.
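A tf.keras.utils.Sequence subclass implements __len__ (number of batches) and __getitem__ (one batch by index). The sketch below is a minimal illustration only; the class name ListBatches and its parameters are invented for this example, and it keeps the data in plain Python lists.

```python
import math

import tensorflow as tf


class ListBatches(tf.keras.utils.Sequence):
    """Hypothetical Sequence that serves (features, labels) batches from lists."""

    def __init__(self, features, labels, batch_size=2):
        super().__init__()
        self.features = features
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return math.ceil(len(self.features) / self.batch_size)

    def __getitem__(self, idx):
        # Slice out the idx-th batch.
        start = idx * self.batch_size
        end = start + self.batch_size
        return self.features[start:end], self.labels[start:end]
```

An instance can be indexed directly, for example batches[0], and Keras's model.fit() also accepts such an object as its input.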
