Connecting Keras models / replacing input but keeping layers

2024/10/16 0:18:55

This question is similar to "Keras replacing input layer".

I have a classifier network and an autoencoder network and I want to use the output of the autoencoder (i.e. encoding + decoding, as a preprocessing step) as the input to the classifier - but after the classifier was already trained on the regular data.

The classification network was built with the functional API like this (based on this example):

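from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Dense

clf_input = Input(shape=(28,28,1))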
clf_layer = Conv2D(...)(clf_input)
clf_layer = MaxPooling2D(...)(clf_layer)
...
clf_output = Dense(num_classes, activation='softmax')(clf_layer)
model = Model(clf_input, clf_output)
model.compile(...)
model.fit(...)

And the autoencoder like this (based on this example):

ae_input = Input(shape=(28,28,1))
x = Conv2D(...)(ae_input)
x = MaxPooling2D(...)(x)
...
encoded = MaxPooling2D(...)(x)
x = Conv2D(...)(encoded)
x = UpSampling2D(...)(x)
...
decoded = Conv2D(...)(x)
autoencoder = Model(ae_input, decoded)
autoencoder.compile(...)
autoencoder.fit(...)

I can concatenate the two models like this (I still need the original models, hence the copying):

model_copy = keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())
# remove original input layer
model_copy.layers.pop(0)
# set the new input
new_clf_output = model_copy(decoded)
# get the stacked model
stacked_model = Model(ae_input, new_clf_output)
stacked_model.compile(...)

And this works great when all I want to do is apply the model to new test data, but it gives an error on something like this:

for layer in stacked_model.layers:
    print(layer.get_config())

where it gets to the end of the autoencoder but then fails with a KeyError at the point where the classifier model gets its input. Also when plotting the model with keras.utils.plot_model I get this:

[Image: plot_model output for stacked_model]

where you can see the autoencoder layers but then at the end, instead of the individual layers from the classifier model, there is only the complete model in one block.

Is there a way to connect two models such that the new stacked model is actually made up of all the individual layers?
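The single block in the plot is consistent with how Keras handles a model that is called on a tensor: the called model becomes one entry in the outer model's `layers` list, and its own layers stay inside it. Here is a minimal sketch with two tiny, hypothetical dense models (`inner` and `outer` are illustrative names, not the models from the question), assuming a TensorFlow 2.x install:

```python
from tensorflow import keras
from tensorflow.keras import layers

# Hypothetical inner model, standing in for the trained classifier.
inner_in = keras.Input(shape=(4,))
inner_out = layers.Dense(2)(inner_in)
inner = keras.Model(inner_in, inner_out)

# Hypothetical outer model that calls `inner` as if it were a layer.
outer_in = keras.Input(shape=(4,))
hidden = layers.Dense(4)(outer_in)
outer = keras.Model(outer_in, inner(hidden))

# The outer model holds the inner model as a single nested entry.
nested = [layer for layer in outer.layers if isinstance(layer, keras.Model)]
print([type(layer).__name__ for layer in outer.layers])

# The inner model's own layers are still reachable through the nested entry.
print([type(layer).__name__ for layer in nested[0].layers])
```

If only the plot matters, recent versions of `keras.utils.plot_model` accept an `expand_nested` argument that draws the nested model's layers; whether it is available depends on the installed Keras version.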

Answer

OK, the best I could come up with is to go through the classifier's layers manually and reconnect them one by one, like this:

l = model.layers[1](decoded)  # layer 0 is the input layer, which we're replacing
for i in range(2, len(model.layers)):
    l = model.layers[i](l)
stacked_model = Model(ae_input, l)
stacked_model.compile(...)

While this works, produces the correct plot, and raises no errors, it does not seem like the most elegant solution...

(By the way, copying the model seems to be unnecessary, as I'm not retraining anything.)
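The layer-by-layer reconnection can be wrapped in a small helper. The following is only a sketch: `stack_layerwise` is a hypothetical name, the two models are tiny stand-ins for the real autoencoder and classifier, and it assumes the second model is a plain chain of layers (no branches or skip connections) whose first layer is the input layer. The layers are reused rather than copied, so the stacked model shares weights with the originals.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def stack_layerwise(front, back):
    """Feed front's output through each layer of back, skipping back's input layer.

    Assumes `back` is a linear chain of layers; branching models need more care.
    """
    x = front.output
    for layer in back.layers[1:]:
        x = layer(x)
    return keras.Model(front.input, x)


# Tiny stand-in for the autoencoder (untrained, for illustration only).
ae_input = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(4, 3, padding="same", activation="relu")(ae_input)
decoded = layers.Conv2D(1, 3, padding="same", activation="sigmoid")(x)
autoencoder = keras.Model(ae_input, decoded)

# Tiny stand-in for the classifier (untrained, for illustration only).
clf_input = keras.Input(shape=(28, 28, 1))
y = layers.Conv2D(4, 3, activation="relu")(clf_input)
y = layers.MaxPooling2D(2)(y)
y = layers.Flatten()(y)
clf_output = layers.Dense(10, activation="softmax")(y)
model = keras.Model(clf_input, clf_output)

stacked_model = stack_layerwise(autoencoder, model)

# Every entry is now an individual layer, so get_config() works on each one.
for layer in stacked_model.layers:
    print(layer.name, type(layer).__name__)

# Sanity check: the stacked model should match running the two models in sequence.
batch = np.random.rand(2, 28, 28, 1).astype("float32")
expected = model.predict(autoencoder.predict(batch, verbose=0), verbose=0)
actual = stacked_model.predict(batch, verbose=0)
print(np.allclose(expected, actual, atol=1e-5))
```

Because the weights are shared, further training of `stacked_model` would also change `model` and `autoencoder`; cloning the classifier first, as in the question, avoids that.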

