How to load and use a pretrained PyTorch InceptionV3 model to classify an image

2024/9/18 18:51:48

I have the same problem as "How can I load and use a PyTorch (.pth.tar) model", which has no accepted answer, and I can't figure out how to follow the advice given there.

I'm new to PyTorch. I am trying to load the pretrained PyTorch model referenced here: https://github.com/macaodha/inat_comp_2018

I'm pretty sure I am missing some glue.

# load the model
import torch
from PIL import Image
from torchvision import transforms

model = torch.load("iNat_2018_InceptionV3.pth.tar", map_location='cpu')

# try to get it to classify an image
imsize = 256
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])

def image_loader(image_name):
    """Load an image and return it as a 1 x C x H x W CPU tensor."""
    image = Image.open(image_name).convert("RGB")
    image = loader(image).float()
    image = image.unsqueeze(0)  # add a batch dimension
    return image.cpu()  # assumes that you're using the CPU

image = image_loader("test-image.jpg")

Calling model.predict(image) then produces the error:

----> 1 model.predict(image)

AttributeError: 'dict' object has no attribute 'predict'

Answer

Problem

What you loaded isn't actually a model. The file is a checkpoint: when it was saved, it stored not only the model's parameters but also other training information, all packed into a dict.

Therefore, torch.load("iNat_2018_InceptionV3.pth.tar") simply returns a dict, which of course has no attribute called predict. (PyTorch modules have no predict method either: once you have a real model, you classify an image by calling it, as in model(image).)

model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
type(model)
# dict
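
The same behaviour can be reproduced with a small stand-in network. The sketch below assumes a checkpoint layout chosen for illustration (a dict with the weights under "state_dict" plus an extra "epoch" entry), saves it to an in-memory buffer and loads it back. For the real file, print(model.keys()) shows what it actually contains.

```python
import io

import torch
import torch.nn as nn

# A tiny stand-in network; the real checkpoint holds InceptionV3 weights.
net = nn.Linear(4, 2)

# Assumed checkpoint layout for illustration: weights plus extra training info.
checkpoint = {"epoch": 10, "state_dict": net.state_dict()}

buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

loaded = torch.load(buffer, map_location="cpu")
print(type(loaded))                # <class 'dict'>
print(sorted(loaded.keys()))       # ['epoch', 'state_dict']
print(hasattr(loaded, "predict"))  # False
```

Whatever torch.save was given comes back unchanged, so a checkpoint dict stays a dict until you load its "state_dict" entry into a model instance.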

Solution

What you need to do first, in this case and in general, is to instantiate the model class you want and then load the saved parameters into it, as per the official guide "Load models".

# First try
from torchvision.models import Inception3
v3 = Inception3()
v3.load_state_dict(model['state_dict']) # the checkpoint dict loaded in your code

However, passing model['state_dict'] in directly raises errors about mismatched shapes for some of Inception3's parameters.

It is important to know what was changed in Inception3 after it was instantiated. Luckily, you can find that in the original author's train_inat.py.

# What the author has done
model = inception_v3(pretrained=True)
model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
model.aux_logits = False

Now that we know what to change, let's modify our first try accordingly.

# Second try
import torch.nn as nn
from torchvision.models import Inception3
v3 = Inception3()
v3.fc = nn.Linear(2048, 8142)
v3.aux_logits = False
v3.load_state_dict(model['state_dict']) # the checkpoint dict loaded in your code

And there you go: the model now loads successfully!

