Injecting pre-trained word2vec vectors into TensorFlow seq2seq

2024/10/15 3:21:06

I am trying to inject pretrained word2vec vectors into an existing TensorFlow seq2seq model.

Following this answer, I wrote the code below. It does not seem to improve performance as I expected, even though the values in the embedding variables are updated.

Could the problem be that EmbeddingWrapper or embedding_attention_decoder creates its embeddings independently of the vocabulary order?

What would be the best way to load pretrained vectors into a TensorFlow model?

SOURCE_EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding"
TARGET_EMBEDDING_KEY = "embedding_attention_seq2seq/embedding_attention_decoder/embedding"


def inject_pretrained_word2vec(session, word2vec_path, input_size, dict_dir, source_vocab_size, target_vocab_size):
    word2vec_model = word2vec.load(word2vec_path, encoding="latin-1")
    print("w2v model created!")
    session.run(tf.initialize_all_variables())
    assign_w2v_pretrained_vectors(session, word2vec_model, SOURCE_EMBEDDING_KEY, source_vocab_path, source_vocab_size)
    assign_w2v_pretrained_vectors(session, word2vec_model, TARGET_EMBEDDING_KEY, target_vocab_path, target_vocab_size)


def assign_w2v_pretrained_vectors(session, word2vec_model, embedding_key, vocab_path, vocab_size):
    vectors_variable = [v for v in tf.trainable_variables() if embedding_key in v.name]
    if len(vectors_variable) != 1:
        print("Word vector variable not found or too many. key: " + embedding_key)
        print("Existing embedding trainable variables:")
        print([v.name for v in tf.trainable_variables() if "embedding" in v.name])
        sys.exit(1)
    vectors_variable = vectors_variable[0]
    vectors = vectors_variable.eval()
    with gfile.GFile(vocab_path, mode="r") as vocab_file:
        counter = 0
        while counter < vocab_size:
            vocab_w = vocab_file.readline().replace("\n", "")
            # for each word in vocabulary check if w2v vector exist and inject.
            # otherwise dont change the value.
            if word2vec_model.__contains__(vocab_w):
                w2w_word_vector = word2vec_model.get_vector(vocab_w)
                vectors[counter] = w2w_word_vector
            counter += 1
    session.run([vectors_variable.initializer], {vectors_variable.initializer.inputs[1]: vectors})

Answer

I am not familiar with the seq2seq example, but in general you can use the following code snippet to inject your embeddings:

Where you build your graph:

with tf.device("/cpu:0"):
    embedding = tf.get_variable("embedding", [vocabulary_size, embedding_size])
    inputs = tf.nn.embedding_lookup(embedding, input_data)

When you execute (after building your graph and before starting the training), just assign your saved embeddings to the embedding variable:

session.run(tf.assign(embedding, embeddings_that_you_want_to_use))
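The array passed to tf.assign needs one row per vocabulary entry, in the same order as the vocabulary file, because row i is what word id i will look up. Here is a minimal sketch of building such a matrix in plain Python. The names are assumptions for illustration only: `pretrained` is a hypothetical dict mapping words to vectors (standing in for the word2vec model), and `build_embedding_matrix` is a hypothetical helper, not part of TensorFlow or the seq2seq example. Words without a pretrained vector get small random values.

```python
import random


def build_embedding_matrix(vocab_words, pretrained, embedding_size, seed=0):
    """Return (matrix, hits): one row per vocabulary word, in vocabulary order.

    vocab_words: list of words; a word's index in the list is its id.
    pretrained:  dict word -> list of floats (hypothetical stand-in for word2vec).
    """
    rng = random.Random(seed)
    matrix = []
    hits = 0
    for word in vocab_words:
        vector = pretrained.get(word)
        if vector is not None and len(vector) == embedding_size:
            # Copy the pretrained vector into the row for this word id.
            matrix.append(list(vector))
            hits += 1
        else:
            # No usable pretrained vector: fall back to small random values.
            matrix.append([rng.uniform(-0.1, 0.1) for _ in range(embedding_size)])
    return matrix, hits


# Toy vocabulary: ids 0 and 1 are special tokens with no pretrained vector.
vocab_words = ["_PAD", "_GO", "the", "cat"]
pretrained = {"the": [0.1, 0.2, 0.3], "cat": [0.4, 0.5, 0.6]}

matrix, hits = build_embedding_matrix(vocab_words, pretrained, embedding_size=3)
print(hits, "of", len(vocab_words), "vocabulary words had a pretrained vector")
```

Printing the hit count is a cheap sanity check: if only a few vocabulary words are found (for example because of different tokenization or casing), the injected vectors are unlikely to change much.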

The idea is that embedding_lookup treats each value in input_data as a row index and replaces it with the corresponding row of the embedding variable.
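To make the lookup concrete, here is a plain-Python analogue (not TensorFlow code). The helper name `lookup_rows` and the toy values are assumptions chosen for illustration; the sketch only shows that each id selects the row at that index, which is why the row order must match the vocabulary order.

```python
def lookup_rows(embedding, ids):
    """Plain-Python analogue of an embedding lookup: pick row i for each id i."""
    return [embedding[i] for i in ids]


embedding = [
    [0.0, 0.0],  # vector for word id 0
    [1.0, 1.5],  # vector for word id 1
    [2.0, 2.5],  # vector for word id 2
]
input_data = [2, 0, 2]

inputs = lookup_rows(embedding, input_data)
print(inputs)  # [[2.0, 2.5], [0.0, 0.0], [2.0, 2.5]]
```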
