Why would this dataset implementation run out of memory?

2024/10/10 13:21:14

I followed this instruction and wrote the following code to create a Dataset for images (the COCO2014 training set):

from pathlib import Path
import tensorflow as tf

def image_dataset(filepath, image_size, batch_size, norm=True):
    def preprocess_image(image):
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, image_size)
        if norm:
            image /= 255.0  # normalize to [0,1] range
        return image

    def load_and_preprocess_image(path):
        image = tf.read_file(path)
        return preprocess_image(image)

    all_image_paths = [str(f) for f in Path(filepath).glob('*')]
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.shuffle(buffer_size = len(all_image_paths))
    ds = ds.repeat()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds

ds = image_dataset(train2014_dir, (256, 256), 4, False)
image = ds.make_one_shot_iterator().get_next('images')
# image is then fed to the network

This code always runs out of both system memory (32G) and GPU memory (11G), and the process gets killed. Here are the messages shown on the terminal: [image]

I also noticed that the program gets stuck at sess.run(opt_op). What is wrong? How can I fix it?

Answer

The problem is this:

ds = ds.shuffle(buffer_size = len(all_image_paths))

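The buffer that Dataset.shuffle() uses is an in-memory buffer. Because shuffle() comes after map() here, that buffer holds decoded, resized images, and with buffer_size set to the number of images you are effectively trying to load the whole dataset into memory.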
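A rough back-of-the-envelope estimate shows the scale of the problem. The sketch below assumes the COCO 2014 training split contains 82,783 images (a figure not stated in the question) and that tf.image.resize produces float32 tensors of shape 256x256x3, as in the call above.

```python
# Assumption: the COCO 2014 training split has 82,783 images.
NUM_IMAGES = 82_783
HEIGHT, WIDTH, CHANNELS = 256, 256, 3
BYTES_PER_FLOAT32 = 4  # resized images are float32 tensors

# Size of one decoded, resized image held in the shuffle buffer.
bytes_per_image = HEIGHT * WIDTH * CHANNELS * BYTES_PER_FLOAT32

# Size of a shuffle buffer that holds every image at once.
total_bytes = NUM_IMAGES * bytes_per_image

print(f"{bytes_per_image / 2**20:.2f} MiB per image")
print(f"{total_bytes / 2**30:.1f} GiB for a full shuffle buffer")
```

Under these assumptions the buffer alone would need about 60 GiB, well beyond the 32G of RAM mentioned in the question.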

You have a couple of options (which you can combine) to fix this:

Option 1:

Reduce the buffer size to a much smaller number.
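To see why a smaller buffer bounds memory use, here is a minimal pure-Python sketch of how a bounded shuffle buffer behaves. It illustrates the idea only and is not TensorFlow's actual implementation; the name buffered_shuffle is made up for this example.

```python
import random

def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a partially shuffled order, holding at most
    buffer_size items in memory at any time (illustrative sketch)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer

# Small demo: only 3 items are ever held at once.
print(list(buffered_shuffle(range(10), buffer_size=3, seed=0)))
```

The trade-off is that a smaller buffer gives a less thorough shuffle: each element can only move a limited distance from its original position, and a buffer of size 1 does not shuffle at all. In the pipeline from the question this would correspond to something like ds.shuffle(buffer_size=1000), where 1000 is only an illustrative value.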

Option 2:

Move the shuffle() statement before the map() statement.

This means we would shuffle before loading the images, so the shuffle buffer would hold only the filenames rather than huge tensors.
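A reordered version of the function from the question might look like the sketch below. It is one possible arrangement rather than a definitive fix. It uses tf.io.read_file in place of tf.read_file on the assumption that a TensorFlow version providing the tf.io namespace is installed, and it merges the two helper functions into one for brevity.

```python
from pathlib import Path
import tensorflow as tf

def image_dataset(filepath, image_size, batch_size, norm=True):
    def load_and_preprocess_image(path):
        # Assumption: tf.io.read_file is available (newer spelling of tf.read_file).
        image = tf.io.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, image_size)
        if norm:
            image /= 255.0  # normalize to [0,1] range
        return image

    all_image_paths = [str(f) for f in Path(filepath).glob('*')]
    ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    # Shuffle while the elements are still short path strings, so a buffer
    # covering the whole dataset costs only a few megabytes.
    ds = ds.shuffle(buffer_size=len(all_image_paths))
    ds = ds.repeat()
    # Decode and resize only after shuffling; images are loaded on demand.
    ds = ds.map(load_and_preprocess_image,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    return ds
```

With this ordering, the only decoded images in memory at a given time are those being processed by map(), the current batch, and the prefetched batches.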

