Can autograd in pytorch handle a repeated use of a layer within the same module?

2024/9/8 10:09:26

I have a layer called layer in an nn.Module and use it two or more times during a single forward step, so the output of this layer is later fed back into the same layer. Can PyTorch's autograd compute the gradient of this layer's weights correctly?

def forward(self, x):
    x = self.layer(x)
    x = self.layer(x)
    return x

Complete example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class net(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(net, self).__init__()
        self.layer = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x):
        x = self.layer(x)
        x = self.layer(x)
        return x

input_x = torch.tensor([10.])
label = torch.tensor([5.])
n = net(1, 1)
loss_fn = nn.MSELoss()

out = n(input_x)
loss = loss_fn(out, label)
n.zero_grad()
loss.backward()

for param in n.parameters():
    w = param.item()
    g = param.grad

print('Input = %.4f; label = %.4f' % (input_x, label))
print('Weight = %.4f; output = %.4f' % (w, out))
print('Gradient w.r.t. the weight is %.4f' % (g))
print('And it should be %.4f' % (4 * (w**2 * input_x - label) * w * input_x))

Output:

Input = 10.0000; label = 5.0000
Weight = 0.9472; output = 8.9717
Gradient w.r.t. the weight is 150.4767
And it should be 150.4766

In this example, I have defined a module with only one linear layer (in_dim=out_dim=1 and no bias). w is the weight of this layer; input_x is the input value; label is the desired value. Since the loss is chosen as MSE, the formula for the loss is

((w^2)*input_x-label)^2

Computing by hand, the derivative of the loss with respect to w is

dloss/dw = 2*((w^2)*input_x-label)*(2*w*input_x)
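The factor 2*w*input_x is where the two uses of the layer show up. Writing x for input_x, \ell for label, y_1 for the output of the first call and y_2 for the output of the second, the chain rule collects one term per use of w (a short derivation, assuming the single-element MSE reduces to a plain squared error, which it does for one element):

```latex
\begin{aligned}
y_1 &= w\,x, \qquad y_2 = w\,y_1, \qquad L = (y_2 - \ell)^2 \\
\frac{dL}{dw} &= \frac{\partial L}{\partial y_2}
  \left(
    \underbrace{\frac{\partial y_2}{\partial w}}_{\text{second call}}
    + \underbrace{\frac{\partial y_2}{\partial y_1}\,\frac{\partial y_1}{\partial w}}_{\text{first call}}
  \right) \\
&= 2\,(y_2 - \ell)\,(w x + w x) = 4\,(w^2 x - \ell)\,w\,x
\end{aligned}
```

The last line is exactly the expression printed as "And it should be" in the example.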

The output of my example above shows that autograd gives the same result as computed by hand, giving me a reason to believe that it can work in this case. But in a real application, the layer may have inputs and outputs of higher dimensions, a nonlinear activation function after it, and the neural network could have multiple layers.

What I want to ask is: can I trust autograd to handle situations like this that are much more complicated than my example? How does it work when a layer is called repeatedly?

Answer

This will work just fine. From the autograd engine's perspective this is not a cycle: the computation graph recorded during the forward pass unrolls the repeated computation into a linear sequence. To illustrate this, for a single layer applied in a loop you might have:

x -----> layer --------+
           ^           |
           |  2 times  |
           +-----------+

From the autograd perspective this looks like:

x ---> layer ---> layer ---> layer
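One way to see this unrolling for yourself is to walk the recorded graph through grad_fn and next_functions. The sketch below is an illustration under assumptions: the helper count_accumulate_edges is a hypothetical name, the layer sizes are arbitrary, and it relies on AccumulateGrad being the internal node type autograd uses to write into a leaf's .grad (an implementation detail that could change between versions). Since layer.weight is the only leaf that requires grad, every edge it finds points at that one weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.Linear(2, 2, bias=False)
x = torch.randn(1, 2)  # plain input: does not require grad

# Apply the same layer three times, as in the unrolled diagram.
out = x
for _ in range(3):
    out = layer(out)

def count_accumulate_edges(fn):
    """Hypothetical helper: count graph edges that end at an AccumulateGrad node.

    A plain recursive walk without a visited set is enough here because this
    small graph is a chain with short side branches; on graphs with heavy
    node sharing it would revisit nodes many times.
    """
    if fn is None:  # inputs that do not require grad have no node
        return 0
    if type(fn).__name__ == "AccumulateGrad":
        return 1
    return sum(count_accumulate_edges(next_fn) for next_fn, _ in fn.next_functions)

# One edge into the weight's accumulator per application of the layer.
print(count_accumulate_edges(out.grad_fn))  # expected: 3
```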

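In this unrolled graph, every layer node refers to the same module and therefore the same weight tensor; nothing is actually copied, the layer simply appears three times. This means that when the gradient for the layer's weights is computed, the contributions from all three stages are accumulated (summed). So when calling backward: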

x ---> layer ---> layer ---> layer ---> loss_func
                                          |
       lback <--- lback <--- lback <------+
         |          |          |
         |          v          |
         +--> weights_grad <---+

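Here lback stands for the layer's backward step: it takes the upstream gradient as input and applies the local derivative of the layer's forward transformation to it (a vector-Jacobian product). Each lback adds its contribution to the layer's weights_grad, which in PyTorch is the weight's .grad attribute.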
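A quick numerical check of the "each one adds to weights_grad" claim, sketched under assumptions (arbitrary sizes, a tanh after each step to mimic the nonlinearity mentioned in the question, MSE loss): run the shared layer three times, then run three independent copies that start from the same weight values, and compare the shared gradient with the sum of the copies' gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(4, 3)
target = torch.randn(4, 3)

# Shared layer applied three times.
shared = nn.Linear(3, 3, bias=False)
out = x
for _ in range(3):
    out = torch.tanh(shared(out))
nn.functional.mse_loss(out, target).backward()
shared_grad = shared.weight.grad.clone()

# Three independent layers initialised with the same weight values.
copies = [nn.Linear(3, 3, bias=False) for _ in range(3)]
with torch.no_grad():
    for c in copies:
        c.weight.copy_(shared.weight)

out = x
for c in copies:
    out = torch.tanh(c(out))
nn.functional.mse_loss(out, target).backward()
summed_grad = sum(c.weight.grad for c in copies)

# The shared gradient equals the per-stage gradients added together.
print(torch.allclose(shared_grad, summed_grad, atol=1e-6))  # True
```

The only difference between the two runs is whether the three stages point at one weight tensor or three, so agreement here reflects autograd summing the per-stage contributions into a single .grad.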

Recurrent neural networks are built on exactly this repeated application of a layer (cell). See, for example, the PyTorch tutorial "Classifying Names with a Character-Level RNN".
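For the more complicated cases raised in the question (higher-dimensional inputs, a nonlinearity, many repetitions), one way to gain confidence is torch.autograd.gradcheck, which compares autograd's gradients against finite differences. The function unrolled below is a hypothetical, RNN-like stand-in (no input sequence, tanh activation, four steps), not code from the question; gradcheck needs double-precision inputs to be reliable.

```python
import torch

torch.manual_seed(0)

def unrolled(x, weight, steps=4):
    # Hypothetical example: apply the same weight matrix `steps` times,
    # with a tanh after each step, like a minimal RNN-style loop.
    h = x
    for _ in range(steps):
        h = torch.tanh(h @ weight.T)
    return h

# Double precision so the finite-difference estimate is accurate enough.
x = torch.randn(2, 5, dtype=torch.double, requires_grad=True)
weight = torch.randn(5, 5, dtype=torch.double, requires_grad=True)

# Returns True if analytical and numerical gradients agree (raises otherwise).
ok = torch.autograd.gradcheck(unrolled, (x, weight))
print(ok)  # True
```

gradcheck checks the gradient with respect to both x and weight, so the weight's gradient is verified including all four reuses.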

https://en.xdnf.cn/q/72821.html
