How can I use an index array in TensorFlow?

2024/10/12 0:30:50

Given a matrix a with shape (5, 3) and an index array b with shape (5,), NumPy makes it easy to get the corresponding vector c, where c[i] = a[i, b[i]]:

c = a[np.arange(5), b]

However, I cannot do the same thing in TensorFlow:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# this line throws error
c = a[tf.range(5), b]

Traceback (most recent call last):
  File "", line 1, in
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 513, in _SliceHelper
    name=name)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 671, in strided_slice
    shrink_axis_mask=shrink_axis_mask)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3688, in strided_slice
    shrink_axis_mask=shrink_axis_mask, name=name)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2397, in create_op
    set_shapes_for_outputs(ret)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1757, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1707, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_14' (op: 'StridedSlice') with input shapes: [5,3], [2,5], [2,5], [2].

My question is: if the NumPy-style indexing above does not work in TensorFlow, how can I get the same result?

Answer

This feature is not currently implemented in TensorFlow. GitHub issue #4638 is tracking the implementation of NumPy-style "advanced" indexing. However, you can use the tf.gather_nd() operator to implement your program:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))
row_indices = tf.range(5)

# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])
c = tf.gather_nd(a, indices)
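
To see why the coordinate matrix works, here is a minimal NumPy sketch of what tf.gather_nd() computes in this 2-D case. The helper gather_nd_2d and the sample values of a and b are made up for illustration; they are not part of TensorFlow or of the original question.

```python
import numpy as np


def gather_nd_2d(params, indices):
    """Hypothetical NumPy stand-in for tf.gather_nd on a 2-D `params`.

    Each row of `indices` is a (row, col) coordinate into `params`;
    the result holds one element per coordinate.
    """
    return np.array([params[r, c] for r, c in indices])


# Made-up example values (assumption, not from the question).
a = np.arange(15, dtype=np.float32).reshape(5, 3)
b = np.array([2, 0, 1, 1, 2], dtype=np.int32)

# Same construction as in the answer: pair each row index with its column index.
indices = np.transpose([np.arange(5), b])  # shape (5, 2)

c = gather_nd_2d(a, indices)

print(indices.tolist())  # [[0, 2], [1, 0], [2, 1], [3, 1], [4, 2]]
print(c.tolist())        # [2.0, 3.0, 7.0, 10.0, 14.0]

# The result matches NumPy's advanced indexing from the question.
assert np.array_equal(c, a[np.arange(5), b])
```

In newer TensorFlow releases, tf.gather(a, b, batch_dims=1) should express the same per-row lookup without building the coordinate matrix; check this against the documentation for your installed version.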