How to use tf.nn.embedding_lookup_sparse in TensorFlow?

2024/9/18 13:19:13

We have tried tf.nn.embedding_lookup and it works, but it needs dense input data. Now we need tf.nn.embedding_lookup_sparse for sparse input.

I have written the following code, but I get errors.

import tensorflow as tf
import numpy as np

example1 = tf.SparseTensor(indices=[[4], [7]], values=[1, 1], shape=[10])
example2 = tf.SparseTensor(indices=[[3], [6], [9]], values=[1, 1, 1], shape=[10])
vocabulary_size = 10
embedding_size = 1
var = np.array([0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0])
#embeddings = tf.Variable(tf.ones([vocabulary_size, embedding_size]))
embeddings = tf.Variable(var)
embed = tf.nn.embedding_lookup_sparse(embeddings, example2, None)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(sess.run(embed))

The error log looks like this.


I have no idea how to fix this or how to use the method correctly. Any comments would be appreciated.

After digging into the unit tests for safe_embedding_lookup_sparse, I'm even more confused about the result I get when passing sparse weights, especially why something like embedding_weights[0][3] shows up when 3 does not appear anywhere in the code above.


Answer

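tf.nn.embedding_lookup_sparse() combines embeddings with segment operations (such as tf.sparse_segment_sum), which use the first coordinate of each index in the SparseTensor as a segment ID. In the TensorFlow version used here, those segment IDs must start at 0 and increase by 1. Your example2 uses indices=[[3], [6], [9]], so its segment IDs start at 3. That's why you get this error.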
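As a framework-free illustration, here is a minimal sketch that models the rule exactly as stated above: the segment ID of each entry is the first coordinate of its index, the IDs must start at 0, and each step may grow by 0 or 1 (repeats are allowed, since the "sum" example further down puts two entries in row 0). Both helper names are hypothetical; this is not TensorFlow code and may not match every TensorFlow version's validation.

```python
# Hypothetical helpers (not TensorFlow APIs) modeling the segment-ID rule.

def segment_ids_from_indices(indices):
    """Segment ID of each entry = first coordinate of its index."""
    return [idx[0] for idx in indices]


def check_segment_ids(segment_ids):
    """Return None if the IDs start at 0 and each step grows by 0 or 1,
    otherwise a message describing the first violation."""
    if not segment_ids:
        return None
    if segment_ids[0] != 0:
        return f"segment IDs must start at 0, got {segment_ids[0]}"
    for prev, cur in zip(segment_ids, segment_ids[1:]):
        if cur - prev not in (0, 1):
            return f"segment IDs must increase by 0 or 1, got {prev} -> {cur}"
    return None


# The question's example2 stores the vocabulary IDs in `indices`:
print(check_segment_ids(segment_ids_from_indices([[3], [6], [9]])))
# Putting the IDs in `values` and using 0, 1, 2 as row indices passes:
print(check_segment_ids(segment_ids_from_indices([[0], [1], [2]])))
```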

Instead of 1s used as boolean flags, the values of your sparse tensor need to hold the IDs of the rows you want to retrieve from embeddings, while indices only says which output row each ID belongs to. Here's your tweaked code:

import tensorflow as tf
import numpy as np
example = tf.SparseTensor(indices=[[0], [1], [2]], values=[3, 6, 9], dense_shape=[3])
vocabulary_size = 10
embedding_size = 1
var = np.array([0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0])
embeddings = tf.Variable(var)
embed = tf.nn.embedding_lookup_sparse(embeddings, example, None)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(sess.run(embed))  # prints [  9.  36.  81.]
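If the goal is to look up both of the question's examples at once, the usual layout for sp_ids is a 2-D SparseTensor whose first index coordinate is the example (batch row) and whose values are the IDs. The sketch below builds those components in plain Python; ids_to_sparse_components is a hypothetical helper, not part of TensorFlow.

```python
# Hypothetical helper (not part of TensorFlow): turn a batch of variable-length
# ID lists into the (indices, values, dense_shape) triple that
# tf.SparseTensor(indices=..., values=..., dense_shape=...) accepts.

def ids_to_sparse_components(id_lists):
    indices, values = [], []
    for row, ids in enumerate(id_lists):
        for col, word_id in enumerate(ids):
            indices.append([row, col])  # row = example in the batch (segment ID)
            values.append(word_id)      # value = row to fetch from `embeddings`
    max_len = max((len(ids) for ids in id_lists), default=0)
    return indices, values, [len(id_lists), max_len]


# The question's example1 (IDs 4, 7) and example2 (IDs 3, 6, 9) as one batch:
indices, values, dense_shape = ids_to_sparse_components([[4, 7], [3, 6, 9]])
print(indices)      # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
print(values)       # [4, 7, 3, 6, 9]
print(dense_shape)  # [2, 3]
```

These lists could then be wrapped in tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape) and passed to tf.nn.embedding_lookup_sparse(embeddings, sp_ids, None, combiner='sum'). With the squares table in var, each row should come out as the sum of its looked-up values: 16 + 49 = 65 and 9 + 36 + 81 = 126.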

In addition, entries of the tf.SparseTensor that share the same row index are combined into a single embedding, using one of the combiners supported by tf.nn.embedding_lookup_sparse() (selected with the combiner argument):

  • "sum" computes the weighted sum of the embedding results for each row.
  • "mean" is the weighted sum divided by the total weight.
  • "sqrtn" is the weighted sum divided by the square root of the sum of the squares of the weights.

For example:

example = tf.SparseTensor(indices=[[0], [0]], values=[1, 2], dense_shape=[2])
...
embed = tf.nn.embedding_lookup_sparse(embeddings, example, None, combiner='sum')
...
print(sess.run(embed)) # prints [ 5.]
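To make the three combiner formulas concrete, here is a pure-Python sketch of what tf.nn.embedding_lookup_sparse computes for a 1-D embedding table like var. It is not TensorFlow's implementation: the function name lookup_sparse and its parameters are hypothetical, weights plays the role of sp_weights, and rows that receive no IDs are simply omitted. It reproduces the [ 5.] result above and shows what "mean" and "sqrtn" give for the same input.

```python
import math
from collections import defaultdict


def lookup_sparse(embeddings, indices, values, weights=None, combiner="sum"):
    """Hypothetical reference for the sum/mean/sqrtn combiners (1-D table)."""
    if weights is None:
        weights = [1.0] * len(values)  # no sp_weights: every ID has weight 1
    sums = defaultdict(float)  # weighted sum of embeddings per output row
    wsum = defaultdict(float)  # sum of weights per output row
    wsq = defaultdict(float)   # sum of squared weights per output row
    for idx, word_id, w in zip(indices, values, weights):
        row = idx[0]  # entries sharing a row index are combined
        sums[row] += w * embeddings[word_id]
        wsum[row] += w
        wsq[row] += w * w
    out = []
    for row in sorted(sums):
        if combiner == "sum":
            out.append(sums[row])
        elif combiner == "mean":
            out.append(sums[row] / wsum[row])
        elif combiner == "sqrtn":
            out.append(sums[row] / math.sqrt(wsq[row]))
        else:
            raise ValueError(f"unknown combiner: {combiner!r}")
    return out


squares = [float(i * i) for i in range(10)]  # same table as `var` above
print(lookup_sparse(squares, [[0], [0]], [1, 2], combiner="sum"))   # [5.0]
print(lookup_sparse(squares, [[0], [0]], [1, 2], combiner="mean"))  # [2.5]
print(lookup_sparse(squares, [[0], [1], [2]], [3, 6, 9]))           # [9.0, 36.0, 81.0]
```

With weights=[2.0, 0.5] on the same two IDs, "mean" gives (2 * 1 + 0.5 * 4) / 2.5 = 1.6, which is the "weighted sum divided by the total weight" from the list above.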
