Minibatch training for inputs of variable sizes

2024/9/29 5:23:21

I have a list of LongTensors and another list of labels. I'm new to PyTorch and RNNs, so I'm quite confused about how to implement minibatch training for the data I have. There is much more to this data, but I want to keep it simple so that I can understand just the minibatch training part. I'm doing multiclass classification based on the final hidden state of an LSTM/GRU trained on variable-length inputs. I managed to get it working with batch size 1 (basically SGD), but I'm struggling with implementing minibatches.

Do I have to pad the sequences to the maximum length and create a new, larger tensor that holds all the elements? I mean something like this:

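import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset
inputs = pad_sequence(sequences, batch_first=True)  # zero-pad every sequence to the longest one
train = DataLoader(TensorDataset(inputs, torch.tensor(labels)), batch_size=batch_size, shuffle=True)
for i, (x, y) in enumerate(train):
    pass  # do stuff using LSTM and/or GRU models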

Is this the accepted way of doing minibatch training on custom data? I couldn't find any tutorials on loading custom data with DataLoader (but I assume that's the way to create batches in PyTorch?).

Another doubt I have is about padding. The reason I'm using an LSTM/GRU is the variable length of the input. Doesn't padding defeat the purpose? Is padding necessary for minibatch training?

Answer

Yes. The issue with minibatch training on sequences which have different lengths is that you can't stack sequences of different lengths together.

Normally one would do something like this:

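import random
import torch
for e in range(epochs):
    random.shuffle(sequences)  # reshuffle the list of sequences in place every epoch
    for mb in range(len(sequences) // mb_size):  # integer division: range() needs an int
        batch = torch.stack(sequences[mb * mb_size:(mb + 1) * mb_size])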

and then you apply your neural network to your batch. But because your sequences have different lengths, torch.stack will fail, since it requires all the tensors to have the same shape. So what you have to do is pad your sequences with zeros so that they all have the same length (at least within a minibatch). You have 2 options:

1) At the very beginning, pad all your sequences with leading zeros so that they all have the same length as the longest sequence in your whole dataset.

OR

2) On the fly, for each minibatch, before stacking the sequences together, pad all the sequences that go into that minibatch with leading zeros so that they all have the same length as the longest sequence in the minibatch.
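A minimal sketch of the two options follows. To keep it self-contained, it assumes each sequence is a plain list of token ids (a stand-in for your LongTensors) and that id 0 is never a real token, so it can serve as padding. The helper names pad_leading, pad_all and minibatches are hypothetical. Every padded batch is rectangular, so in PyTorch you could turn it into a (batch, max_len) LongTensor with torch.tensor(x); the same per-batch padding logic could also live in a custom collate_fn passed to DataLoader. Note that torch.nn.utils.rnn.pad_sequence pads at the end (trailing zeros) rather than with leading zeros.

```python
import random

# Assumption: a sequence is a list of token ids and 0 is a dedicated padding id.
PAD = 0

def pad_leading(seqs, length):
    """Left-pad every sequence with PAD so it has exactly `length` elements."""
    return [[PAD] * (length - len(s)) + list(s) for s in seqs]

# Option 1: pad once, to the longest sequence in the whole dataset.
def pad_all(sequences):
    max_len = max(len(s) for s in sequences)
    return pad_leading(sequences, max_len)

# Option 2: pad on the fly, to the longest sequence in each minibatch.
def minibatches(sequences, labels, mb_size, shuffle=True):
    idx = list(range(len(sequences)))
    if shuffle:
        random.shuffle(idx)  # shuffle indices so sequences and labels stay aligned
    for start in range(0, len(idx), mb_size):  # also yields the last, partial batch
        chunk = idx[start:start + mb_size]
        batch = [sequences[i] for i in chunk]
        max_len = max(len(s) for s in batch)
        yield pad_leading(batch, max_len), [labels[i] for i in chunk]

if __name__ == "__main__":
    seqs = [[5, 3], [7, 1, 4, 2], [9], [6, 8, 2]]
    labels = [0, 2, 1, 0]
    print(pad_all(seqs))
    # [[0, 0, 5, 3], [7, 1, 4, 2], [0, 0, 0, 9], [0, 6, 8, 2]]
    for x, y in minibatches(seqs, labels, mb_size=2, shuffle=False):
        print(x, y)
    # [[0, 0, 5, 3], [7, 1, 4, 2]] [0, 2]
    # [[0, 0, 9], [6, 8, 2]] [1, 0]
```

Option 2 usually wastes less computation: the second minibatch above is padded only to length 3, its own longest sequence, instead of the dataset-wide maximum of 4. Unlike the loop earlier in the answer, this version also keeps the final partial minibatch instead of dropping it.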

https://en.xdnf.cn/q/71251.html
