In Keras/TensorFlow, how to add CNN layers after the last layer of a ResNet50V2 pre-trained on ImageNet

2024/9/20 7:22:11

I am trying to drop the last layer and add a simple CNN in its place, as follows:

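from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential

base_model = ResNet50V2(include_top=False, weights="imagenet", input_shape=input_shape, pooling="avg")
base_model.trainable = False
model = Sequential()
model.add(base_model)  # I want to add the following CNN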
model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(MaxPooling2D((2, 2), padding='same'))
model.add(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(MaxPooling2D((2, 2), padding='same'))
model.add(Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(MaxPooling2D((2, 2), padding='same'))
model.add(Flatten())
model.add(Dense(128, activation='relu', kernel_initializer='he_uniform')) 
model.add(Dense(1, activation='sigmoid'))

I don't know what I am missing in making this connection; I get the following error:

    model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
  File "/project/6035234/npiran/classification/venv_clf/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py", line 522, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/project/6035234/npiran/classification/venv_clf/lib/python3.7/site-packages/tensorflow/python/keras/engine/sequential.py", line 228, in add
    output_tensor = layer(self.outputs[0])
  File "/project/6035234/npiran/classification/venv_clf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 970, in __call__
    input_list)
  File "/project/6035234/npiran/classification/venv_clf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1108, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "/project/6035234/npiran/classification/venv_clf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 840, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/project/6035234/npiran/classification/venv_clf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 878, in _infer_output_signature
    self._maybe_build(inputs)
  File "/project/6035234/npiran/classification/venv_clf/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 2600, in _maybe_build
    self.input_spec, inputs, self.name)
  File "/project/6035234/npiran/classification/venv_clf/lib/python3.7/site-packages/tensorflow/python/keras/engine/input_spec.py", line 235, in assert_input_compatibility
    str(tuple(shape)))
ValueError: Input 0 of layer conv2d is incompatible with the layer: : expected min_ndim=4, found ndim=2. Full shape received: (None, 2048)

Answer

Update Version

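If you want to use the convolutional layers that already exist in ResNet50V2, then instead of setting base_model.trainable = False for all layers, freeze only the earlier layers and train the later ones, as below. Then use option_2 (described under Old Version) and pass the output to tf.keras.layers.Flatten(). The indices shown come from the model used in this answer; print the layer names of your own base model to choose the cut-off.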

for idx, layer in enumerate(base_model.layers):
    print(f'The name of {idx} layers is   {layer.name}')
# The name of 142 layers is   conv4_block6_out
# The name of 143 layers is   conv5_block1_1_conv
# The name of 144 layers is   conv5_block1_1_bn
...
# The name of 173 layers is   conv5_block3_add
# The name of 174 layers is   conv5_block3_out
# The name of 175 layers is   avg_pool

for layer in base_model.layers[:143]:
    layer.trainable = False
for layer in base_model.layers[143:]:
    layer.trainable = True
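
The hard-coded index 143 above is tied to one particular model. As a minimal sketch, the same freeze/unfreeze split can be keyed on a layer name instead. The helper freeze_before is hypothetical (it is not part of Keras), and the stand-in layer objects exist only so the snippet runs without TensorFlow; with a real model you would pass base_model.layers.

```python
from types import SimpleNamespace


def freeze_before(layers, first_trainable_name):
    """Freeze every layer before `first_trainable_name`; unfreeze the rest.

    Hypothetical helper: works on any sequence of objects that expose
    `.name` and `.trainable`, such as `base_model.layers`.
    Returns the index at which the split was made.
    """
    names = [layer.name for layer in layers]
    split = names.index(first_trainable_name)  # raises ValueError if the name is absent
    for layer in layers[:split]:
        layer.trainable = False
    for layer in layers[split:]:
        layer.trainable = True
    return split


# Stand-in layers (an assumption for illustration, not real Keras layers).
demo_layers = [
    SimpleNamespace(name=name, trainable=True)
    for name in ["conv4_block6_out", "conv5_block1_1_conv", "conv5_block3_out", "avg_pool"]
]
split = freeze_before(demo_layers, "conv5_block1_1_conv")
print(split, [layer.trainable for layer in demo_layers])
```

Looking the layer up by name fails loudly if the name does not exist in your base model, which is easier to notice than silently freezing the wrong slice.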

Old Version: You have two options:

  1. Use tf.keras.layers.Reshape((2,2,512)) to reshape (None, 2048) -> (None, 2, 2, 512). (But ResNet50V2 already contains convolutional layers, so why would you need more?)
  2. Pass the output of ResNet50V2 to tf.keras.layers.Flatten(). You can try it as below:

Network for option_1:

model.add(base_model)
model.add(tf.keras.layers.Reshape((2,2,512)))
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(tf.keras.layers.MaxPooling2D((2, 2), padding='same'))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(tf.keras.layers.MaxPooling2D((2, 2), padding='same'))
model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
model.add(tf.keras.layers.MaxPooling2D((2, 2), padding='same'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=.2))    
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
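
To see why Reshape((2,2,512)) is legal and what the extra layers then do to the tensor, the shapes in option_1 can be traced by hand. This is a plain-Python sketch of the arithmetic only. It assumes that padding='same' keeps the spatial size for a stride-1 convolution and gives ceil(size / stride) for pooling; same_pool is a hypothetical helper name.

```python
import math


def same_pool(size, pool=2):
    """Spatial size after MaxPooling2D(pool, padding='same') with stride == pool."""
    return math.ceil(size / pool)


# Reshape is only valid when the element counts match.
features = 2048
h, w, c = 2, 2, 512
assert h * w * c == features

shapes = [(h, w, c)]
for filters in (32, 64, 128):
    # Conv2D with padding='same' and stride 1 keeps height and width.
    shapes.append((h, w, filters))
    # Each pooling layer divides height and width by 2, rounding up.
    h, w = same_pool(h), same_pool(w)
    shapes.append((h, w, filters))

flattened = h * w * shapes[-1][2]  # number of values that reach Flatten()
print(shapes)
print(flattened)
```

Under these assumptions the feature map is already 1x1 after the first pooling layer, so the later Conv2D and MaxPooling2D layers see only a single spatial position. That supports the remark in option 1 that the extra CNN adds little here.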

Full code and Network for option_2:

import tensorflow as tf

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

base_model = tf.keras.applications.ResNet50(weights="imagenet", include_top=False, pooling="avg", input_shape=(32,32,3))
base_model.trainable = False

model = tf.keras.Sequential()
model.add(base_model)
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=.2))
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), optimizer='Adam', metrics=['accuracy'])
model.fit(X_train, y_train, batch_size=256, epochs=2, validation_split=.2)

Output:

Epoch 1/2
157/157 [==============================] - 10s 47ms/step - loss: 2.2472 - accuracy: 0.1683 - val_loss: 2.0121 - val_accuracy: 0.2772
Epoch 2/2
157/157 [==============================] - 6s 40ms/step - loss: 2.0074 - accuracy: 0.2566 - val_loss: 1.9161 - val_accuracy: 0.2934
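
Separately from the two options above, and not part of the original answer: the (None, 2048) shape in the error comes from pooling="avg", which averages the feature map over its spatial dimensions. A possible alternative, sketched here under assumptions, is to build the base model with pooling=None so that it returns a 4-D feature map that Conv2D accepts. The 64x64 input size, the shortened head, and weights=None (used only to avoid downloading the ImageNet weights in this sketch) are illustrative choices, not taken from the question.

```python
import tensorflow as tf

input_shape = (64, 64, 3)  # assumed for illustration

# weights=None skips the ImageNet download for this sketch;
# use weights="imagenet" for actual transfer learning.
base_model = tf.keras.applications.ResNet50V2(
    include_top=False, weights=None, input_shape=input_shape, pooling=None
)
base_model.trainable = False

inputs = tf.keras.Input(shape=input_shape)
# With pooling=None the output stays 4-D: (batch, height, width, channels).
x = base_model(inputs, training=False)
x = tf.keras.layers.Conv2D(32, (3, 3), activation="relu",
                           kernel_initializer="he_uniform", padding="same")(x)
x = tf.keras.layers.MaxPooling2D((2, 2), padding="same")(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(128, activation="relu", kernel_initializer="he_uniform")(x)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.Model(inputs, outputs)

feature_shape = tuple(base_model.output_shape)
prediction_shape = tuple(model(tf.zeros((1,) + input_shape)).shape)
print(feature_shape, prediction_shape)
```

The functional API is used here instead of Sequential only to keep the sketch explicit about where the base model's output enters the new layers.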
