Cannot take the length of Shape with unknown rank

2024/10/1 23:44:59

I have a neural network built from a tf.data data generator and a tf.keras model, as follows (this is a simplified version, because the full code would be too long):

dataset = ...

dataset is a tf.data.Dataset-based object whose next_x method calls get_next on the x_train iterator and whose next_y method calls get_next on the y_train iterator. Each label is a one-hot array of shape (1, 67).

Layers:

import tensorflow as tf

input_tensor = tf.keras.layers.Input(shape=(240, 240, 3))  # dim of x
output = tf.keras.layers.Flatten()(input_tensor)
output = tf.keras.layers.Dense(67, activation='softmax')(output)  # 67 is the number of classes

Model:

model = tf.keras.models.Model(inputs=input_tensor, outputs=output)
model.compile(optimizer=tf.train.AdamOptimizer(), loss=tf.losses.softmax_cross_entropy, metrics=['accuracy'])
model.fit_generator(gen(dataset.next_x(), dataset.next_y()), steps_per_epoch=100)

gen is defined like this:

def gen(x, y):
    while True:
        yield (x, y)
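A side note on gen: its arguments are evaluated once, when gen(dataset.next_x(), dataset.next_y()) is called, so every yield returns the very same (x, y) objects. Whether that matters depends on what get_next returns (a concrete batch in eager mode, a symbolic tensor in graph mode), but it is worth being aware of. The plain-Python sketch below illustrates the difference using a hypothetical FakeDataset stand-in (not a TensorFlow API) whose methods count their calls, and a hypothetical gen_fresh variant that calls the methods inside the loop.

```python
import itertools


class FakeDataset:
    """Hypothetical stand-in for the question's dataset object (not a TensorFlow API)."""

    def __init__(self):
        self.calls = 0

    def next_x(self):
        self.calls += 1
        return f"x{self.calls}"

    def next_y(self):
        return f"y{self.calls}"


def gen(x, y):
    # x and y were evaluated once by the caller, so this yields the same pair forever.
    while True:
        yield (x, y)


def gen_fresh(dataset):
    # Hypothetical variant: calls the dataset methods on every step, so each yield is new.
    while True:
        yield (dataset.next_x(), dataset.next_y())


ds = FakeDataset()
stale = list(itertools.islice(gen(ds.next_x(), ds.next_y()), 3))
fresh = list(itertools.islice(gen_fresh(ds), 3))
print(stale)  # [('x1', 'y1'), ('x1', 'y1'), ('x1', 'y1')]
print(fresh)  # [('x2', 'y2'), ('x3', 'y3'), ('x4', 'y4')]
```

This only demonstrates Python generator semantics; it does not by itself explain or fix the ValueError below.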

My problem is that when I run it, I get the following error in the model.fit_generator call:

ValueError: Cannot take the length of Shape with unknown rank.

Any ideas are appreciated!

Answer

Could you post the full stack trace? I think your problem might be related to this recent TensorFlow issue:

https://github.com/tensorflow/tensorflow/issues/24520

There is also a simple PR that fixes it (not yet merged at the time of writing). Maybe try applying it yourself?

EDIT

Here is the change from the PR. Open tensorflow/python/keras/engine/training_utils.py and replace the following (line 232 at the time of writing):

  if (x.shape is not None and len(x.shape) == 1

with this:

  if tensor_util.is_tensor(x):
    x_shape_ndims = x.shape.ndims if x.shape is not None else None
  else:
    x_shape_ndims = len(x.shape)
  if (x_shape_ndims == 1
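Why the patch helps: in TensorFlow, calling len() on a TensorShape whose rank is unknown raises exactly "Cannot take the length of Shape with unknown rank.", while its ndims attribute just returns None. The plain-Python sketch below mimics that behaviour with a hypothetical FakeShape class (a stand-in, not TensorFlow's TensorShape) and contrasts the fragment of the original check with a simplified version of the patched logic; the is_tensor flag stands in for tensor_util.is_tensor(x).

```python
class FakeShape:
    """Hypothetical stand-in for a shape object; NOT TensorFlow's TensorShape."""

    def __init__(self, dims=None):
        # dims=None models a shape whose rank is unknown.
        self.dims = dims

    @property
    def ndims(self):
        # Unknown rank is reported as None instead of raising.
        return None if self.dims is None else len(self.dims)

    def __len__(self):
        # Mimics the error raised in the question.
        if self.dims is None:
            raise ValueError("Cannot take the length of Shape with unknown rank.")
        return len(self.dims)


def is_rank_one_original(shape):
    # Fragment of the original check: len() raises when the rank is unknown.
    return shape is not None and len(shape) == 1


def ndims_patched(shape, is_tensor):
    # Mirrors the patch: tensors report ndims (None if unknown), other inputs use len().
    if is_tensor:
        return shape.ndims if shape is not None else None
    return len(shape)


unknown = FakeShape()
try:
    is_rank_one_original(unknown)
except ValueError as err:
    print("original check:", err)

print(ndims_patched(unknown, True) == 1)         # False: unknown rank no longer raises
print(ndims_patched(FakeShape([5]), True) == 1)  # True
print(ndims_patched((32, 67), False))            # 2, e.g. a NumPy-style shape tuple
```

With the patch, an input of unknown rank simply fails the "== 1" comparison instead of aborting fit_generator.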