Can Keras model.predict return a dictionary?

2024/10/8 22:19:50

The documentation (https://keras.io/models/model/#predict) says that model.predict returns Numpy array(s) of predictions. In the Keras API, is there a way to tell which of these arrays is which? How about in the TF implementation?

At the top of the same documentation page, it says that "models can specify multiple inputs and outputs using lists". It seems that nothing breaks if one passes dictionaries instead:

my_model = tf.keras.models.Model(inputs=my_inputs_dict, outputs=my_outputs_dict)

When calling model.fit the same documentation says "If input layers in the model are named, you can also pass a dictionary mapping input names to Numpy arrays."

It would be nice if either the keys of my_outputs_dict or the names of its values (the output layers) were attached to the outputs of my_model.predict(...).

If I save the model in TensorFlow's SavedModel (protobuf) format using tf.keras.models.save_model, the TF Serving API works this way, with named inputs and outputs.

Answer

Use my_model.output_names

Given

my_model = tf.keras.models.Model(inputs=my_inputs_dict, outputs=my_outputs_dict)

create the dict yourself from my_model.output_names, which lists the name attributes of your output layers in the same order as the predictions:

prediction_list = my_model.predict(my_test_input_dict)
prediction_dict = {name: pred for name, pred in zip(my_model.output_names, prediction_list)}
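The one-line zip above assumes predict returned a list with one array per output. A minimal sketch of a slightly more defensive version follows. The helper name `predictions_to_dict` is hypothetical, and it assumes that a single-output model returns one bare array rather than a one-element list (zipping a bare array would pair names with its rows). It also passes a dict through unchanged in case your Keras version already returns one.

```python
from collections.abc import Mapping
from typing import Any, Dict, Sequence


def predictions_to_dict(output_names: Sequence[str], predictions: Any) -> Dict[str, Any]:
    """Pair each prediction with its output layer name (hypothetical helper).

    `predictions` is assumed to be whatever model.predict returned:
    a dict, a list/tuple with one array per output (in output_names order),
    or a single array for a model with exactly one output.
    """
    if isinstance(predictions, Mapping):
        # Already keyed by name: nothing to pair up.
        return dict(predictions)
    if isinstance(predictions, (list, tuple)):
        # One entry per output; the order must match output_names.
        if len(predictions) != len(output_names):
            raise ValueError(
                f"got {len(predictions)} predictions for {len(output_names)} outputs"
            )
        return dict(zip(output_names, predictions))
    # Assumed single-output case: predict returned one array, not a list.
    if len(output_names) != 1:
        raise ValueError("expected a list of predictions for a multi-output model")
    return {output_names[0]: predictions}
```

Usage with the names from the question: `prediction_dict = predictions_to_dict(my_model.output_names, my_model.predict(my_test_input_dict))`.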