How to extract a feature vector from a single image in PyTorch?

2024/11/15 9:15:00

I'm trying to learn more about how computer vision models work. To get a better sense of how to interpret feature vectors, I'm using PyTorch to extract one from a pretrained ResNet-18. Below is the code I've pieced together from various places.

```python
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image

img = Image.open("Documents/01235.png")

# Load the pretrained model
model = models.resnet18(pretrained=True)

# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

# Set model to evaluation mode
model.eval()

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def get_vector(image_name):
    # Load the image with Pillow library
    img = Image.open("Documents/Documents/Driven Data Competitions/Hateful Memes Identification/data/01235.png")
    # Create a PyTorch Variable with the transformed image
    t_img = transforms(img)
    # Create a vector of zeros that will hold our feature vector
    # The 'avgpool' layer has an output size of 512
    my_embedding = torch.zeros(512)
    # Define a function that will copy the output of a layer
    def copy_data(m, i, o):
        my_embedding.copy_(o.data)
    # Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)
    # Run the model on our transformed image
    model(t_img)
    # Detach our copy function from the layer
    h.remove()
    # Return the feature vector
    return my_embedding

pic_vector = get_vector(img)
```

When I do this I get the following error:

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 224, 224] instead

I'm sure this is an elementary error, but I can't figure out how to fix it. I was under the impression that the ToTensor transformation would make my data 4-D, but either it isn't working correctly or I'm misunderstanding it. I'd appreciate any help or resources I can use to learn more about this!

Answer

PyTorch's nn.Modules are designed around an additional leading batch dimension; for an image model like ResNet that means input of shape (B, C, H, W). If the input to a module has shape (B, ...), the output will have shape (B, ...) as well (though the later dimensions may change depending on the layer). This allows efficient inference on B inputs simultaneously. ToTensor only converts a PIL image into a (C, H, W) tensor; it does not add a batch dimension. To make your code conform, unsqueeze an additional singleton dimension onto the front of the t_img tensor before passing it to your model, making it a (1, C, H, W) tensor. You will also need to flatten the output of layer, which has shape (1, 512, 1, 1), before copying it into your one-dimensional my_embedding tensor.

A few other things:

  • You should run inference within a torch.no_grad() context: you won't be computing gradients, so there's no need for autograd to record the operations (which costs memory and time). Note that model.eval() only changes the behavior of certain layers, such as dropout and batch normalization; it doesn't disable construction of the computation graph, but torch.no_grad() does.

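  • I assume this is just a copy-paste issue, but transforms is both the name of an imported module and a global variable, so the assignment shadows the module. Similarly, torchvision itself is never imported (import torchvision.models as models binds only the name models), so torchvision.transforms.Compose would raise a NameError as posted.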

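  • o.data is a holdover from the old Variable interface (circa PyTorch 0.3.1 and earlier), where it was needed to get at the tensor wrapped by a Variable. Variable was merged into Tensor and deprecated back in PyTorch 0.4.0, so wrapping tensors in Variable no longer does anything useful, and o.data now just returns a tensor that shares o's storage but is detached from the autograd graph (o.detach() is the recommended way to do that when you actually need it). Here it only creates confusion, as does the unused from torch.autograd import Variable. Unfortunately, many tutorials are still written using this old and unnecessary interface.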
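The difference between model.eval() and torch.no_grad(), and what .detach() returns, can be checked with a tiny stand-in model. This is a minimal sketch; the nn.Sequential below is hypothetical and exists only to contain a Dropout layer:

```python
import torch
import torch.nn as nn

# Hypothetical toy model: one Linear layer followed by Dropout
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
x = torch.randn(2, 4)

model.eval()                   # changes Dropout/BatchNorm behavior only
out_eval = model(x)
print(out_eval.requires_grad)  # True: autograd still recorded the graph

with torch.no_grad():          # disables graph construction
    out_nograd = model(x)
print(out_nograd.requires_grad)  # False

# detach() returns a tensor sharing the same storage, cut off from autograd
d = out_eval.detach()
print(d.requires_grad, d.data_ptr() == out_eval.data_ptr())  # False True
```

Because Dropout is a no-op in eval mode, both forward passes produce the same values; only the autograd bookkeeping differs.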

Updated code is then as follows:

```python
import torch
import torchvision
import torchvision.models as models
from PIL import Image

# convert("RGB") ensures 3 channels; PNGs may be RGBA or grayscale
img = Image.open("Documents/01235.png").convert("RGB")

# Load the pretrained model
# (weights= replaces pretrained=True, which is deprecated since torchvision 0.13)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

# Set model to evaluation mode
model.eval()

transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def get_vector(image):
    # Create a PyTorch tensor with the transformed image
    t_img = transforms(image)
    # Create a vector of zeros that will hold our feature vector
    # The 'avgpool' layer has an output size of 512
    my_embedding = torch.zeros(512)

    # Define a function that will copy the output of a layer
    def copy_data(m, i, o):
        my_embedding.copy_(o.flatten())                 # <-- flatten

    # Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)
    # Run the model on our transformed image
    with torch.no_grad():                               # <-- no_grad context
        model(t_img.unsqueeze(0))                       # <-- unsqueeze
    # Detach our copy function from the layer
    h.remove()
    # Return the feature vector
    return my_embedding

pic_vector = get_vector(img)
```
