How to remove the last FC layer from a ResNet model in PyTorch?

2024/11/19 7:40:37

I am using a ResNet152 model from PyTorch. I'd like to strip off the last FC layer from the model. Here's my code:

from torchvision import datasets, transforms, models
model = models.resnet152(pretrained=True)
print(model)

When I print the model, the last few lines look like this:

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

I want to remove that last fc layer from the model.

I found an answer here on SO (How to convert pretrained FC layers to CONV layers in Pytorch), where mexmex seems to provide the answer I'm looking for:

list(model.modules()) # to inspect the modules of your model
my_model = nn.Sequential(*list(model.modules())[:-1]) # strips off last linear layer

So I added those lines to my code like this:

import torch.nn as nn
from torchvision import models

model = models.resnet152(pretrained=True)
list(model.modules()) # to inspect the modules of your model
my_model = nn.Sequential(*list(model.modules())[:-1]) # strips off last linear layer
print(my_model)

But this code doesn't work as advertised, at least not for me. The rest of this post is a detailed explanation of why that answer doesn't work, so this question doesn't get closed as a duplicate.

First, the printed model is nearly 5x larger than before. I see the same model as before, followed by what appears to be a repeat of the model, perhaps flattened.

      (2): Bottleneck(
        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
    )
    (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
    (fc): Linear(in_features=2048, out_features=1000, bias=True)
  )
  (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU(inplace)
  (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (5): Sequential(
  . . . this goes on for ~1600 more lines . . .
  (415): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (416): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (417): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (418): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (419): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (420): ReLU(inplace)
  (421): AvgPool2d(kernel_size=7, stride=1, padding=0)
)

Second, the fc layer is still there, and the Conv2d layer after it looks just like the first layer of ResNet152.

Third, if I try to invoke my_model.forward(), PyTorch complains about a size mismatch. It expects size [1, 3, 224, 224], but the input was [1, 1000]. So it looks like a copy of the entire model (minus the fc layer) is getting appended to the original model.

Bottom line, the only answer I found on SO doesn't actually work.

Answer

For ResNet models, you can use the children() method to access the layers, since ResNet models in PyTorch are composed of nn modules. (Tested on PyTorch 0.4.1)

import torch
from torchvision import models

model = models.resnet152(pretrained=True)
newmodel = torch.nn.Sequential(*(list(model.children())[:-1]))  # drop the last child (fc)
print(newmodel)

Update: Although there is no universal answer that works for every PyTorch model, this approach should work on all well-structured ones. The built-in layers you add to your model (such as torch.nn.Linear, torch.nn.Conv2d, torch.nn.BatchNorm2d...) are all based on the torch.nn.Module class, and if you implement a custom layer and add it to your network, you should inherit from PyTorch's torch.nn.Module class as well. As written in the documentation, the children() method lets you access the modules of your class/model/network.

def children(self):
    r"""Returns an iterator over immediate children modules."""
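To illustrate the point about custom layers, here is a sketch with a hypothetical network (`TinyNet` and `ScaleLayer` are made-up names for illustration) that mixes built-in layers with a custom `nn.Module` subclass. Every layer assigned as an attribute in `__init__` is registered as a submodule, so `named_children()` and `children()` yield them in assignment order. One caveat: `nn.Sequential(*children)` only reproduces the original behavior when `forward` simply applies the children in that order; any extra operations in `forward` (ResNet's `forward`, for instance, calls `torch.flatten` before `fc`) are not carried over.

```python
import torch
from torch import nn


class ScaleLayer(nn.Module):
    """Hypothetical custom layer: multiplies its input by a learnable scalar."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # a parameter, not a child module

    def forward(self, x):
        return x * self.scale


class TinyNet(nn.Module):
    """Hypothetical network whose forward applies its children in order."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.act = nn.ReLU()
        self.scale = ScaleLayer()  # custom layer, registered like any built-in one
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(self.scale(self.act(self.fc1(x))))


net = TinyNet()
names = [name for name, _ in net.named_children()]
print(names)

# Drop the last immediate child (fc2); the remaining children form a valid Sequential.
headless = nn.Sequential(*list(net.children())[:-1])
out = headless(torch.randn(3, 8))
print(out.shape)
```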

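Update: It is important to note that children() returns only the "immediate" child modules. This means that if the last module of your network is an nn.Sequential, children() returns that whole Sequential as a single item, so slicing off the last child removes the entire Sequential rather than just its final layer.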
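The nested case can be sketched with a hypothetical model (`NetWithHead`, with attributes `body` and `head`, all made up for illustration) whose last immediate child is an `nn.Sequential` classifier head. Slicing `model.children()` drops that entire head; to remove only its final `Linear` layer, one option is to slice inside the Sequential instead and assign the shorter head back.

```python
import torch
from torch import nn


class NetWithHead(nn.Module):
    """Hypothetical model whose last immediate child is a Sequential head."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(10, 32), nn.ReLU())
        self.head = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 3))

    def forward(self, x):
        return self.head(self.body(x))


model = NetWithHead()
x = torch.randn(4, 10)

# children() sees only `body` and `head`, so [:-1] removes the whole head.
without_head = nn.Sequential(*list(model.children())[:-1])
print(without_head(x).shape)  # output of body: 32 features per sample

# To drop only the final Linear layer, slice the head's own children instead.
model.head = nn.Sequential(*list(model.head.children())[:-1])
print(model(x).shape)  # output of the trimmed head: 16 features per sample
```

`nn.Sequential` also supports slicing directly (`model.head[:-1]` returns a new Sequential), which is an equivalent way to write the last step.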

