PyTorch DataLoader uses same random seed for batches run in parallel

2024/7/6 17:40:20

There is a bug in PyTorch/NumPy where, when loading batches in parallel with a DataLoader (i.e. setting num_workers > 1), the same NumPy random seed is used for each worker, so any random functions applied give identical results across parallelized batches.

Minimal example:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 9

dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=1, num_workers=3)

for batch in dataloader:
    print(batch)

As you can see, within each parallelized set of 3 batches, the results are identical:

# First 3 batches
tensor([[891, 674]])
tensor([[891, 674]])
tensor([[891, 674]])
# Second 3 batches
tensor([[545, 977]])
tensor([[545, 977]])
tensor([[545, 977]])
# Third 3 batches
tensor([[880, 688]])
tensor([[880, 688]])
tensor([[880, 688]])
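Why this happens: to my knowledge, PyTorch releases before 1.9 reseeded each worker's torch and Python `random` generators but not NumPy's global generator, so workers started with fork (the default on Linux) all begin from an identical copy of the parent's NumPy state. From 1.9 on, PyTorch also seeds NumPy per worker, so the example above may not reproduce on newer versions. A minimal sketch of the effect, simulating the fork-style copy with `RandomState.get_state`/`set_state` instead of real processes (an assumption made so it runs anywhere):

```python
import numpy as np

# Stand-in for the parent process's global NumPy RNG (seeded here only to
# make the sketch deterministic).
parent = np.random.RandomState(0)
snapshot = parent.get_state()

# Hypothetical model of fork(): every "worker" starts from an identical copy
# of the parent's RNG state and nothing reseeds it afterwards.
workers = []
for _ in range(3):
    rng = np.random.RandomState()
    rng.set_state(snapshot)
    workers.append(rng)

# Each worker draws its first sample, like __getitem__ in the question.
first_samples = [rng.randint(0, 1000, 2).tolist() for rng in workers]
print(first_samples)  # three identical pairs
```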

What is the recommended (or most elegant) way to fix this, i.e. to have each batch produce a different randomization irrespective of the number of workers?

Answer

It seems this works, at least in Colab:

dataloader = DataLoader(dataset, batch_size=1, num_workers=3, worker_init_fn=lambda id: np.random.seed(id))

EDIT:

As iacob pointed out in the comments, this still produces identical output (i.e. the same problem), now repeated from one epoch to the next.
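The repetition comes from the fact that, with the default persistent_workers=False, fresh workers are started and worker_init_fn is called again at the beginning of every epoch, and worker ids are always 0 to num_workers - 1. Seeding with the id alone therefore gives the same seeds, and the same draws, every epoch. A sketch that simulates this sequentially without a real DataLoader (an assumption for portability; `simulate_epoch` is a hypothetical helper):

```python
import numpy as np


def simulate_epoch(num_workers):
    """Hypothetical stand-in for one epoch: each worker starts fresh, runs
    worker_init_fn = lambda id: np.random.seed(id), then draws one sample."""
    samples = []
    for worker_id in range(num_workers):
        np.random.seed(worker_id)  # what the worker_init_fn above does
        samples.append(np.random.randint(0, 1000, 2).tolist())
    return samples


epoch_0 = simulate_epoch(3)
epoch_1 = simulate_epoch(3)
print(epoch_0)
print(epoch_1)  # same as epoch_0: the seeds 0, 1, 2 repeat every epoch
```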

Best fix I have found so far:

...
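# `epoch` is read when the workers call this lambda (at the start of each epoch), not when the loader is built
dataloader = DataLoader(ds, num_workers=num_w, worker_init_fn=lambda id: np.random.seed(id + epoch * num_w))

for epoch in range(2):
    for batch in dataloader:
        print(batch)
    print()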
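A sketch of an alternative that does not depend on a global epoch variable: inside a worker, torch.initial_seed() returns the seed PyTorch gave that worker (a base seed plus the worker id), and the base seed is redrawn each time a new iterator is created, i.e. every epoch with the default persistent_workers=False. The seed_worker function follows the pattern in PyTorch's reproducibility notes; RandomDataset is copied from the question, and the collect helper is hypothetical, only there to show two epochs.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    # Same toy dataset as in the question.
    def __getitem__(self, index):
        return np.random.randint(0, 1000, 2)

    def __len__(self):
        return 9


def seed_worker(worker_id):
    # PyTorch has already seeded torch in this worker with base_seed + worker_id;
    # reuse that value for NumPy and Python's random module.
    worker_seed = torch.initial_seed() % 2**32  # np.random.seed needs a 32-bit value
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def collect(num_workers=3, epochs=2, seed=0):
    # Hypothetical helper: returns one list of batches (as tuples) per epoch.
    g = torch.Generator()
    g.manual_seed(seed)  # makes the per-epoch base seeds reproducible across runs
    loader = DataLoader(
        RandomDataset(),
        batch_size=1,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )
    return [[tuple(batch[0].tolist()) for batch in loader] for _ in range(epochs)]


if __name__ == "__main__":
    for epoch, batches in enumerate(collect()):
        print(epoch, batches)
```

Using a module-level function instead of a lambda also keeps worker_init_fn picklable, which matters when workers are started with spawn (the default on Windows and macOS).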

I still can't suggest a closed form: the seed depends on a variable (epoch) that is only read when the lambda is called. Ideally it would be something like worker_init_fn = lambda id: np.random.seed(id + EAGER_EVAL(np.random.randint(10000))), where EAGER_EVAL evaluates the seed at loader construction, before the lambda is passed as a parameter. I wonder whether that is possible in Python.
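In Python, default argument values are evaluated once, when the function or lambda is defined, so they can play the role of the hypothetical EAGER_EVAL. A minimal sketch contrasting late binding with a default argument; `offset`, `late`, `eager` and `init_fn` are illustrative names, not from the original:

```python
import numpy as np

offset = 100
late = lambda worker_id: worker_id + offset  # looks up `offset` each time it is called
eager = lambda worker_id, base=offset: worker_id + base  # `base` is captured right now
offset = 200

print(late(1))   # 201
print(eager(1))  # 101

# Applied to the question (hypothetical usage): `base` is drawn once, when the
# lambda is created, i.e. at loader construction.
init_fn = lambda worker_id, base=np.random.randint(10000): np.random.seed(worker_id + base)
```

Note that `base` stays fixed for the loader's lifetime, so this changes the streams between runs but not between epochs; per-epoch variation still needs something that changes every epoch, for example recreating the loader (and thus redrawing `base`) at the start of each epoch.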

https://en.xdnf.cn/q/73226.html