Extracting intermediate layer outputs of a CNN in PyTorch


I am using a ResNet18 model:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

I want to extract the outputs only from layer2, layer3, and layer4, and I don't want the avgpool and fc outputs. How do I achieve this?

import torch.nn as nn


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, padding=1) -> None:
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride,
                               padding=padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 projection on the skip connection when the channel count changes
        if in_channels != out_channels:
            l1 = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                           stride=stride, bias=False)
            l2 = nn.BatchNorm2d(out_channels)
            self.downsample = nn.Sequential(l1, l2)
        else:
            self.downsample = None

    def forward(self, xb):
        prev = xb
        x = self.relu(self.bn1(self.conv1(xb)))
        x = self.bn2(self.conv2(x))
        if self.downsample is not None:
            prev = self.downsample(xb)
        x = x + prev
        return self.relu(x)


class CustomResnet(nn.Module):
    def __init__(self, pretrained: bool = True) -> None:
        super(CustomResnet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = nn.Sequential(BasicBlock(64, 64, stride=1), BasicBlock(64, 64))
        self.layer2 = nn.Sequential(BasicBlock(64, 128, stride=2), BasicBlock(128, 128))
        self.layer3 = nn.Sequential(BasicBlock(128, 256, stride=2), BasicBlock(256, 256))
        self.layer4 = nn.Sequential(BasicBlock(256, 512, stride=2), BasicBlock(512, 512))

    def forward(self, xb):
        x = self.maxpool(self.relu(self.bn1(self.conv1(xb))))
        x = self.layer1(x)
        x2 = x = self.layer2(x)
        x3 = x = self.layer3(x)
        x4 = x = self.layer4(x)
        return [x2, x3, x4]

I guess one solution would be something like the code above, but is there any other way that avoids writing this whole lot of code? Also, is it possible to load the pre-trained weights provided by torchvision into the modified ResNet model above?

Answer

If you know how the forward method is implemented, then you can subclass the model and override only the forward method.

If you are using the pre-trained weights of a model in PyTorch, then you already have access to that model's code. So find where the model is defined, import it, subclass it, and override the forward method.

For example:


from torchvision.models.resnet import ResNet


class MyResNet18(ResNet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, xb):
        x = self.maxpool(self.relu(self.bn1(self.conv1(xb))))
        x = self.layer1(x)
        x2 = x = self.layer2(x)
        x3 = x = self.layer3(x)
        x4 = x = self.layer4(x)
        return [x2, x3, x4]

and you are done.
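As for the second part of the question: the subclass keeps exactly the same submodules and parameter names as torchvision's ResNet18, so the pre-trained state dict loads into it directly. Here is a minimal usage sketch, assuming torchvision >= 0.13 (which provides the ResNet18_Weights API; older releases would use resnet18(pretrained=True) instead):

import torch
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models.resnet import BasicBlock

# Same constructor arguments torchvision uses to build resnet18,
# then copy the pre-trained parameters into the subclass.
model = MyResNet18(BasicBlock, [2, 2, 2, 2])
model.load_state_dict(resnet18(weights=ResNet18_Weights.DEFAULT).state_dict())
model.eval()

# The overridden forward returns the layer2, layer3 and layer4 feature maps.
with torch.no_grad():
    x2, x3, x4 = model(torch.randn(1, 3, 224, 224))
print(x2.shape, x3.shape, x4.shape)
# torch.Size([1, 128, 28, 28]) torch.Size([1, 256, 14, 14]) torch.Size([1, 512, 7, 7])

Note that avgpool and fc are still constructed (and their weights loaded); they are simply never called in the new forward, which costs almost nothing.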
