Extract features from last hidden layer Pytorch Resnet18

2024/9/28 5:25:24

I am implementing an image classifier using the Oxford Pet dataset with the pre-trained Resnet18 CNN. The dataset consists of 37 categories with ~200 images in each of them.

Rather than using the final fc layer of the CNN as the output for predictions, I want to use the CNN as a feature extractor to classify the pets.

For each image I'd like to grab the features from the last hidden layer (the layer just before the 1000-dimensional output layer). My model uses ReLU activations, so I should grab the output just after the ReLU, where all values are non-negative.

Here is my code (following the PyTorch transfer learning tutorial):

loading data

import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

image_datasets = {
    "train": datasets.ImageFolder('images_new/train', transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])),
    "test": datasets.ImageFolder('images_new/test', transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize])),
}

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4, pin_memory=True)
               for x in ['train', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
train_class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train function

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # step the LR scheduler once per epoch, after the optimizer updates
            # (PyTorch >= 1.1 expects optimizer.step() before scheduler.step())
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

Compute SGD cross-entropy loss

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
print("number of features: ", num_ftrs)
model_ft.fc = nn.Linear(num_ftrs, len(train_class_names))
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=24)

Now how do I get a feature vector from the last hidden layer for each of my images? I know I have to freeze the previous layers so that gradients aren't computed for them, but I'm having trouble extracting the feature vectors.

My ultimate goal is to use those feature vectors to train a linear classifier such as Ridge or something like that.
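To make that goal concrete, here is a small sketch of a ridge classifier fitted in closed form on one-hot targets, written in plain PyTorch so it stays self-contained. It is a stand-in, not scikit-learn's `RidgeClassifier` (which would be the more usual choice). The function names, `alpha=1.0` and the synthetic 16-d features are all assumptions for illustration; real inputs would be the extracted 512-d feature vectors.

```python
import torch


def add_bias_column(features):
    # Append a column of ones so the model learns an intercept.
    ones = torch.ones(features.shape[0], 1, dtype=features.dtype)
    return torch.cat([features, ones], dim=1).double()


def fit_ridge_classifier(features, labels, num_classes, alpha=1.0):
    """Solve (X^T X + alpha * I) W = X^T Y with one-hot targets Y."""
    X = add_bias_column(features)
    Y = torch.nn.functional.one_hot(labels, num_classes).double()
    reg = alpha * torch.eye(X.shape[1], dtype=torch.double)
    reg[-1, -1] = 0.0  # do not penalise the bias term
    return torch.linalg.solve(X.T @ X + reg, X.T @ Y)


def predict(weights, features):
    # The predicted class is the one with the largest regression score.
    return (add_bias_column(features) @ weights).argmax(dim=1)


# Synthetic, well-separated features standing in for CNN features.
torch.manual_seed(0)
num_classes, dim = 3, 16
centers = 3 * torch.randn(num_classes, dim)
all_labels = torch.arange(300) % num_classes
all_feats = centers[all_labels] + torch.randn(300, dim)

train_feats, test_feats = all_feats[:240], all_feats[240:]
train_labels, test_labels = all_labels[:240], all_labels[240:]

weights = fit_ridge_classifier(train_feats, train_labels, num_classes)
accuracy = (predict(weights, test_feats) == test_labels).double().mean().item()
print("test accuracy:", accuracy)
```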

Thanks!

Answer

You can try the approach below. It works for any layer; only the slice offset into the list of child modules changes.

import torch
from torchvision import models

model_ft = models.resnet18(pretrained=True)
# strip the last (fc) layer
feature_extractor = torch.nn.Sequential(*list(model_ft.children())[:-1])
# check this works
x = torch.randn([1, 3, 224, 224])
output = feature_extractor(x)  # output now has the features corresponding to input x
print(output.shape)

torch.Size([1, 512, 1, 1])

