How to properly use the TensorFlow Dataset API with batching?


I am new to TensorFlow and deep learning, and I am struggling with the Dataset class. I have tried a lot of things and I can't find a good solution.

What I am trying to do

I have a large number of images (500k+) to train my DNN with. The network is a denoising autoencoder, so I have a pair for each image. I am using the TF Dataset class to manage the data, but I think I am using it badly.

Here is how I load the filenames in a dataset:

class Data:
    def __init__(self, in_path, out_path):
        self.nb_images = 512
        self.test_ratio = 0.2
        self.batch_size = 8

        # load filenames in input and outputs
        inputs, outputs, self.nb_images = self._load_data_pair_paths(in_path, out_path, self.nb_images)
        self.size_training = self.nb_images - int(self.nb_images * self.test_ratio)
        self.size_test = int(self.nb_images * self.test_ratio)

        # split arrays in training / validation
        test_data_in, training_data_in = self._split_test_data(inputs, self.test_ratio)
        test_data_out, training_data_out = self._split_test_data(outputs, self.test_ratio)

        # transform arrays to tf.data.Dataset
        self.train_dataset = tf.data.Dataset.from_tensor_slices((training_data_in, training_data_out))
        self.test_dataset = tf.data.Dataset.from_tensor_slices((test_data_in, test_data_out))

I have a function that is called at each epoch to prepare the dataset. It shuffles the filenames, maps the filenames to images, and batches the data.

def get_batched_data(self, seed, batch_size):
    nb_batch = int(self.size_training / batch_size)

    def img_to_tensor(path_in, path_out):
        img_string_in = tf.read_file(path_in)
        img_string_out = tf.read_file(path_out)
        im_in = tf.image.decode_jpeg(img_string_in, channels=1)
        im_out = tf.image.decode_jpeg(img_string_out, channels=1)
        return im_in, im_out

    t_datas = self.train_dataset.shuffle(self.size_training, seed=seed)
    t_datas = t_datas.map(img_to_tensor)
    t_datas = t_datas.batch(batch_size)
    return t_datas

During training, at each epoch we call the get_batched_data function, make an iterator, run it for each batch, and then feed the resulting arrays to the optimizer operation.

for epoch in range(nb_epoch):
    sess_iter_in = tf.Session()
    sess_iter_out = tf.Session()

    batched_train = data.get_batched_data(epoch)
    iterator_train = batched_train.make_one_shot_iterator()
    in_data, out_data = iterator_train.get_next()

    total_batch = int(data.size_training / batch_size)
    for batch in range(total_batch):
        print(f"{batch + 1} / {total_batch}")
        in_images = sess_iter_in.run(in_data).reshape((-1, 64, 64, 1))
        out_images = sess_iter_out.run(out_data).reshape((-1, 64, 64, 1))
        sess.run(optimizer, feed_dict={inputs: in_images, outputs: out_images})

What do I need?

I need a pipeline that loads only the images of the current batch (otherwise the data will not fit in memory), and I want to shuffle the dataset in a different way for each epoch.

Questions and problems

First question: am I using the Dataset class in a good way? I have seen very different approaches on the internet; for example, in this blog post the dataset is used with a placeholder that is fed with the data during training. That seems strange because the data are all in an array, hence already loaded in memory, so I don't see the point of using tf.data.Dataset in this case.

I found a solution using repeat(epoch) on the dataset, like this, but the shuffle will not be different for each epoch in that case.

The second problem with my implementation is that I get an OutOfRangeError in some cases. With a small amount of data (512 as in the example) it works fine, but with more data the error occurs. I thought it was caused by a bad calculation of the number of batches due to rounding, or by the last batch containing fewer samples, but the error happens at batch 32 out of 115... Is there any way to know the number of batches created after a batch(n) call on a dataset?

Sorry for this loooonng question, but I've been struggling with this for a few days.

Answer

As far as I know, the official Performance Guide is the best teaching material for building input pipelines.

I want to shuffle the dataset in a different way for each epoch.

Using shuffle() and repeat(), you can get a different shuffle pattern for each epoch. You can confirm this with the following code:

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4])
dataset = dataset.shuffle(4)
dataset = dataset.repeat(3)
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(x))

You can also use tf.contrib.data.shuffle_and_repeat, as mentioned in the official guide above.
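For illustration, here is a minimal sketch of how shuffle_and_repeat could replace the separate shuffle() and repeat() calls in TF 1.x (the buffer size and repeat count are just example values, not taken from your code):

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
# fuse shuffling and repeating into a single transformation (TF 1.x contrib API)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=4, count=3))

iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()

with tf.Session() as sess:
    for i in range(12):
        print(sess.run(x))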

There are some problems in your code beyond the data pipeline itself. You are confusing graph construction with graph execution: because you rebuild the input pipeline inside the epoch loop, you end up with as many redundant input pipelines as there are epochs. You can observe these redundant pipelines in TensorBoard.

You should place your graph construction code outside of the loop, as in the following code (pseudocode):

batched_train = data.get_batched_data()
iterator = batched_train.make_initializable_iterator()
in_data, out_data = iterator.get_next()

for epoch in range(nb_epoch):
    # reset the iterator's state at the start of each epoch
    sess.run(iterator.initializer)
    try:
        while True:
            in_images = sess.run(in_data).reshape((-1, 64, 64, 1))
            out_images = sess.run(out_data).reshape((-1, 64, 64, 1))
            sess.run(optimizer, feed_dict={inputs: in_images, outputs: out_images})
    except tf.errors.OutOfRangeError:
        pass

Moreover, there are some minor inefficiencies. You loaded the list of file paths with from_tensor_slices(), so the whole list is embedded in your graph as constants. (See https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays for details.)
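As a rough sketch of the pattern that guide describes (the placeholder names below are mine, and img_to_tensor is the mapping function from your question), you could feed the filename lists through placeholders and an initializable iterator instead of embedding them as graph constants:

# assumption: training_data_in / training_data_out are Python lists of file paths
paths_in_ph = tf.placeholder(tf.string, shape=[None])
paths_out_ph = tf.placeholder(tf.string, shape=[None])

dataset = tf.data.Dataset.from_tensor_slices((paths_in_ph, paths_out_ph))
dataset = dataset.shuffle(buffer_size=len(training_data_in))
dataset = dataset.map(img_to_tensor)
dataset = dataset.batch(batch_size)

iterator = dataset.make_initializable_iterator()
in_data, out_data = iterator.get_next()

with tf.Session() as sess:
    # the file lists are fed once when the iterator is initialized,
    # so they are not stored as constants in the graph
    sess.run(iterator.initializer,
             feed_dict={paths_in_ph: training_data_in,
                        paths_out_ph: training_data_out})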

You would also be better off using prefetch, and you can decrease the number of sess.run calls by combining the input pipeline with the rest of your graph, so that the training step consumes the dataset tensors directly instead of going through feed_dict.
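A possible sketch of that combined graph (build_model is a hypothetical function standing in for your autoencoder construction code; it is not from your question):

dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)  # prepare the next batch while the current one is training

iterator = dataset.make_initializable_iterator()
in_data, out_data = iterator.get_next()

# build the model directly on the dataset tensors, so no feed_dict is needed
optimizer = build_model(in_data, out_data)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(nb_epoch):
        sess.run(iterator.initializer)
        try:
            while True:
                sess.run(optimizer)  # a single sess.run per training step
        except tf.errors.OutOfRangeError:
            pass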
