Extracting Intermediate layer outputs of a CNN in PyTorch

2024/9/29 9:35:06

I am using a ResNet18 model.

ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(1): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): BasicBlock((conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): 
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer3): Sequential((0): BasicBlock((conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer4): Sequential((0): BasicBlock((conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=512, out_features=1000, bias=True)

I want to extract the outputs only from layer2, layer3 and layer4; I don't want the avgpool and fc outputs. How do I achieve this?

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions and a shortcut connection.

    The main path is conv -> batch norm -> ReLU -> conv -> batch norm. The
    block input (projected by ``downsample`` when one exists) is added to the
    main-path result before the final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution and of the shortcut projection.
        padding: Padding of the first convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 1,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 3, stride, padding=padding, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, 3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut needs a projection whenever the main path changes the
        # tensor shape: either the channel count changes or the first
        # convolution is strided. Checking only the channels would leave a
        # strided, same-width block adding tensors of different spatial size.
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, xb):
        """Apply the block to ``xb`` and return the activated residual sum."""
        out = self.relu(self.bn1(self.conv1(xb)))
        out = self.bn2(self.conv2(out))
        shortcut = xb if self.downsample is None else self.downsample(xb)
        return self.relu(out + shortcut)


class CustomResnet(nn.Module):
    """ResNet-18 style backbone that returns intermediate feature maps.

    The network consists of the stem (conv1, bn1, relu, maxpool) followed by
    layer1..layer4, each a pair of ``BasicBlock`` modules. There is no average
    pooling or fully connected head.

    Args:
        pretrained: Accepted by the constructor but not used; no weights are
            loaded here, so all parameters keep their default initialisation.
    """

    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.Sequential(
            BasicBlock(64, 64, stride=1), BasicBlock(64, 64)
        )
        self.layer2 = nn.Sequential(
            BasicBlock(64, 128, stride=2), BasicBlock(128, 128)
        )
        self.layer3 = nn.Sequential(
            BasicBlock(128, 256, stride=2), BasicBlock(256, 256)
        )
        self.layer4 = nn.Sequential(
            BasicBlock(256, 512, stride=2), BasicBlock(512, 512)
        )

    def forward(self, xb):
        """Return ``[layer2_out, layer3_out, layer4_out]`` for input ``xb``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(xb))))
        x = self.layer1(x)
        x2 = self.layer2(x)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return [x2, x3, x4]

I guess one solution would be this. But is there any other way to do it without writing this whole lot of code? Also, is it possible to load the pre-trained weights provided by torchvision into the modified ResNet model above?


If you know how the forward method is implemented, then you can subclass the model, and override the forward method only.

If you are using the pre-trained weights of a model in PyTorch, then you already have access to the code of the model. So, find where the code of the model is, import it, subclass the model, and override the forward method.

For example:

# NOTE(review): ResNet must be imported from the module that defines the
# model (presumably torchvision.models.resnet) - confirm against your install.
class MyResNet18(ResNet):
    """ResNet subclass whose forward pass returns intermediate feature maps.

    Construction is inherited unchanged from ``ResNet``; only ``forward`` is
    overridden, so it stops after ``layer4`` and never calls the pooling or
    fully connected layers.
    """

    def forward(self, xb):
        """Return ``[layer2_out, layer3_out, layer4_out]`` for input ``xb``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(xb))))
        x = self.layer1(x)
        x2 = self.layer2(x)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return [x2, x3, x4]

and you are done.


