BERT get sentence embedding


I am replicating code from this page. I have downloaded the BERT model to my local system and am generating sentence embeddings.

I have around 500,000 sentences for which I need sentence embeddings, and it is taking a lot of time.

  1. Is there a way to expedite the process?
  2. Would sending batches of sentences rather than one sentence at a time help?


#!pip install transformers
import torch
import transformers
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states=True)  # Whether the model returns all hidden-states.

# Put the model in "evaluation" mode, meaning feed-forward operation.
model.eval()

corpa = ["i am a boy", "i live in a city"]
storage = []  # list to store all embeddings

for text in corpa:
    # Add the special tokens.
    marked_text = "[CLS] " + text + " [SEP]"

    # Split the sentence into tokens.
    tokenized_text = tokenizer.tokenize(marked_text)

    # Map the token strings to their vocabulary indices.
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

    segments_ids = [1] * len(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    # Run the text through BERT, and collect all of the hidden states
    # produced from all 12 layers.
    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)

    # Evaluating the model will return a different number of objects based on
    # how it's configured in the `from_pretrained` call earlier. In this case,
    # because we set `output_hidden_states = True`, the third item will be the
    # hidden states from all layers. See the documentation for more details:
    # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
    hidden_states = outputs[2]

    # `hidden_states` has shape [13 x 1 x num_tokens x 768];
    # `token_vecs` is a tensor with shape [num_tokens x 768].
    token_vecs = hidden_states[-2][0]

    # Calculate the average of all token vectors.
    sentence_embedding = torch.mean(token_vecs, dim=0)
    storage.append((text, sentence_embedding))

Update 1

I modified my code based upon the answer provided. It is still not doing full batch processing:

#!pip install transformers
import torch
import transformers
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states=True)  # Whether the model returns all hidden-states.

# Put the model in "evaluation" mode, meaning feed-forward operation.
model.eval()

batch_sentences = ["Hello I'm a single sentence",
                   "And another sentence",
                   "And the very very last one"]
encoded_inputs = tokenizer(batch_sentences)
storage = []  # list to store all embeddings

for i, text in enumerate(encoded_inputs['input_ids']):
    tokens_tensor = torch.tensor([encoded_inputs['input_ids'][i]])
    segments_tensors = torch.tensor([encoded_inputs['attention_mask'][i]])
    print(tokens_tensor)
    print(segments_tensors)

    # Run the text through BERT, and collect all of the hidden states
    # produced from all 12 layers.
    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)

    # Because we set `output_hidden_states = True`, the third item will be
    # the hidden states from all layers.
    hidden_states = outputs[2]

    # Token vectors from the second-to-last layer: [num_tokens x 768].
    token_vecs = hidden_states[-2][0]

    # Calculate the average of all token vectors.
    sentence_embedding = torch.mean(token_vecs, dim=0)
    print(sentence_embedding[:10])
    storage.append((text, sentence_embedding))

I could replace the first two lines of the for loop with the lines below, but they only work if all sentences have the same length after tokenization:

tokens_tensor = torch.tensor([encoded_inputs['input_ids']])
segments_tensors = torch.tensor([encoded_inputs['attention_mask']])

Moreover, in that case outputs = model(tokens_tensor, segments_tensors) fails.

How can I perform full batch processing in this case?

Answer

One of the easiest ways to accelerate your workflow is batch processing. Your current implementation feeds only ONE sentence through the model per iteration, but the model is perfectly capable of handling batched data!

If you want to implement this part yourself, I highly recommend preparing your data with the tokenizer like this:

batch_sentences = ["Hello I'm a single sentence","And another sentence","And the very very last one"]
encoded_inputs = tokenizer(batch_sentences)
print(encoded_inputs)
{'input_ids': [[101, 8667, 146, 112, 182, 170, 1423, 5650, 102],
               [101, 1262, 1330, 5650, 102],
               [101, 1262, 1103, 1304, 1304, 1314, 1141, 102]],
 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 0, 0, 0, 0, 0]],
 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1, 1, 1, 1]]}
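Note that these lists are ragged: each sentence produces a different number of tokens, which is exactly why stacking them into one tensor fails. The tokenizer can also pad the batch and return PyTorch tensors directly; a minimal sketch, assuming the same tokenizer and model objects as above:

encoded_inputs = tokenizer(batch_sentences, padding=True, return_tensors='pt')
print(encoded_inputs['input_ids'].shape)  # e.g. torch.Size([3, 9]): padded to the longest sentence

with torch.no_grad():
    outputs = model(**encoded_inputs)  # the whole batch in one forward pass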

But there is a simpler approach: the FeatureExtractionPipeline, which has comprehensive documentation. It looks like this:

from transformers import pipeline

feature_extraction = pipeline('feature-extraction', model="distilroberta-base", tokenizer="distilroberta-base")
features = feature_extraction(["Hello I'm a single sentence","And another sentence","And the very very last one"])
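The returned features is a plain nested list with one entry per sentence, each holding the per-token vectors, so you still have to pool them into one vector per sentence. A sketch of mean pooling over tokens (the exact nesting of the pipeline output can vary between transformers versions, so check the shape first):

import numpy as np

first = np.array(features[0])               # e.g. shape (1, num_tokens, 768)
sentence_embedding = first[0].mean(axis=0)  # one 768-dim vector for the first sentence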

UPDATE1: You changed your code slightly, but you are still passing samples one at a time, not in batched form. If we want to stick to your implementation, batch processing would look something like this:

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states=True)  # Whether the model returns all hidden-states.
model.eval()

sentences = [
    "Hello I'm a single sentence",
    "And another sentence",
    "And the very very last one",
    "Hello I'm a single sentence",
    "And another sentence",
    "And the very very last one",
    "Hello I'm a single sentence",
    "And another sentence",
    "And the very very last one",
]
batch_size = 4

for idx in range(0, len(sentences), batch_size):
    batch = sentences[idx : min(len(sentences), idx + batch_size)]

    # encoded = tokenizer(batch)
    encoded = tokenizer.batch_encode_plus(batch, max_length=50,
                                          padding='max_length', truncation=True)
    encoded = {key: torch.LongTensor(value) for key, value in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    print(outputs.last_hidden_state.size())

output:

torch.Size([4, 50, 768]) # batch_size * max_length * hidden dim
torch.Size([4, 50, 768])
torch.Size([1, 50, 768]) 
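Batching helps, but for 500,000 sentences the larger win is usually a GPU; this goes beyond the original answer, so treat it as an optional sketch. The same loop, with the model and inputs moved to CUDA when available:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for idx in range(0, len(sentences), batch_size):
    batch = sentences[idx : idx + batch_size]
    encoded = tokenizer.batch_encode_plus(batch, max_length=50,
                                          padding='max_length', truncation=True)
    # Inputs must live on the same device as the model.
    encoded = {key: torch.LongTensor(value).to(device) for key, value in encoded.items()}
    with torch.no_grad():
        outputs = model(**encoded)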

UPDATE2

There are two questions about padding the batched data to the maximum length. First, can the padding distract the transformer model with irrelevant information? NO: during training the model was presented with variable-length sentences in batched form, and the designers introduced a specific parameter, the attention mask, to tell the model WHERE it should attend! Second, how do you get rid of this garbage data afterwards? Using the attention mask, you can perform the mean operation only over the relevant positions.

So the code would be changed to something like this:

for idx in range(0, len(sentences), batch_size):
    batch = sentences[idx : min(len(sentences), idx + batch_size)]

    # encoded = tokenizer(batch)
    encoded = tokenizer.batch_encode_plus(batch, max_length=50,
                                          padding='max_length', truncation=True)
    encoded = {key: torch.LongTensor(value) for key, value in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    lhs = outputs.last_hidden_state  # [batch_size, max_length, 768]

    # Broadcast the attention mask across the hidden dimension so that
    # padded positions are zeroed out before pooling.
    attention = encoded['attention_mask'].reshape((lhs.size()[0], lhs.size()[1], -1)).expand(-1, -1, 768)
    embeddings = torch.mul(lhs, attention)

    # Average only over the non-padded positions. (count_nonzero treats an
    # exactly-zero activation at a real token as padding, which is rare in practice.)
    denominator = torch.count_nonzero(embeddings, dim=1)
    summation = torch.sum(embeddings, dim=1)
    mean_embeddings = torch.div(summation, denominator)
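To tie this back to the original goal of one stored vector per sentence: mean_embeddings has one row per sentence in the batch, so the results can be collected into the question's storage list. Pairing the rows with batch is my assumption about how the results are consumed:

storage = []  # (sentence, embedding) pairs, as in the question

# ... inside the batch loop above, after mean_embeddings is computed:
for sentence, embedding in zip(batch, mean_embeddings):
    storage.append((sentence, embedding))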