How to convert string labels to one-hot vectors in TensorFlow?

2024/9/24 9:27:20

I'm new to TensorFlow and would like to read a comma-separated values (CSV) file containing two columns: column 1 is the index and column 2 is a label string. The code below reads the CSV file line by line, and print statements confirm that the data is read correctly. However, I would like to convert the string labels to one-hot vectors and do not know how to do that in TensorFlow. The final goal is to use the tf.train.batch() function to get batches of one-hot label vectors for training a neural network.

As you can see in the code below, I can create a one-hot vector for each of the label entries manually within a TensorFlow session. But how do I use the tf.train.batch() function? If I move the line

label_batch = tf.train.batch([col2], batch_size=5)

into the TensorFlow session block (replacing col2 with label_one_hot), the program hangs and does nothing. I tried to move the one-hot conversion outside the TensorFlow session, but I could not get it to work correctly. What is the correct way to do this? Please help.

label_files = []
label_files.append(LABEL_FILE)
print "label_files: ", label_files

filename_queue = tf.train.string_input_producer(label_files)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
print "key:", key, ", value:", value

record_defaults = [['default_id'], ['default_label']]
col1, col2 = tf.decode_csv(value, record_defaults=record_defaults)

num_lines = sum(1 for line in open(LABEL_FILE))

label_batch = tf.train.batch([col2], batch_size=5)

with tf.Session() as sess:
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coordinator)

    for i in range(100):
        column1, column2 = sess.run([col1, col2])

        index = 0
        if column2 == 'airplane':
            index = 0
        elif column2 == 'automobile':
            index = 1
        elif column2 == 'bird':
            index = 2
        elif column2 == 'cat':
            index = 3
        elif column2 == 'deer':
            index = 4
        elif column2 == 'dog':
            index = 5
        elif column2 == 'frog':
            index = 6
        elif column2 == 'horse':
            index = 7
        elif column2 == 'ship':
            index = 8
        elif column2 == 'truck':
            index = 9

        label_one_hot = tf.one_hot([index], 10)  # depth=10 for 10 categories
        print "column1:", column1, ", column2:", column2
        # print "onehot label:", sess.run([label_one_hot])

        print sess.run(label_batch)

    coordinator.request_stop()
    coordinator.join(threads)

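The ten-branch if/elif chain is a fixed mapping from class name to index, so a dictionary can replace it. The following is a framework-free sketch in plain Python, not TensorFlow code: SAMPLE_CSV, read_labels, one_hot and batches are hypothetical helpers introduced for illustration. It parses a two-column file, converts each label to a one-hot list, and groups the results in fives to mirror batch_size=5. It only illustrates the data transformation and does not reproduce the queue-based behavior of tf.train.batch.

```python
import csv
import io

# The ten class names from the question, in the same order as its
# if/elif chain, so each name keeps the same index (airplane=0 ... truck=9).
CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# name -> index dictionary that replaces the if/elif chain.
CLASS_TO_INDEX = {name: i for i, name in enumerate(CLASS_NAMES)}

# Hypothetical sample rows standing in for LABEL_FILE: id, label string.
SAMPLE_CSV = "0,airplane\n1,cat\n2,truck\n3,bird\n4,dog\n5,ship\n6,frog\n"


def one_hot(label, depth=len(CLASS_NAMES)):
    """Return `depth` floats with 1.0 at the label's index.

    Unlike the original chain, which silently fell back to index 0,
    an unknown label raises KeyError here.
    """
    vec = [0.0] * depth
    vec[CLASS_TO_INDEX[label]] = 1.0
    return vec


def read_labels(f):
    """Return the label column (column 2) of a two-column CSV file object."""
    return [row[1] for row in csv.reader(f) if row]


def batches(items, batch_size):
    """Yield consecutive lists of batch_size items; the last may be shorter."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


labels = read_labels(io.StringIO(SAMPLE_CSV))
one_hot_batches = [[one_hot(label) for label in batch]
                   for batch in batches(labels, 5)]
print(len(one_hot_batches))               # 2
print(one_hot_batches[0][1].index(1.0))   # 3 ('cat')
```

In a TensorFlow 1.x input pipeline the same idea applies, but the lookup and one-hot step has to be built as graph ops on col2 before calling tf.train.batch, not computed in Python inside the session loop. A likely reason the program hangs is that tf.train.batch registers its own queue runner, and tf.train.start_queue_runners only starts the runners that exist when it is called, so a batch op created afterwards is never fed.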
Answer

It's been more than 2 years since this question was asked, but this answer might still be relevant for some. Here's one simple way to transform string labels into one-hot vectors in TF:

import tensorflow as tf

vocab = ['a', 'b', 'c']
labels = tf.placeholder(dtype=tf.string, shape=(None,))
matches = tf.stack([tf.equal(labels, s) for s in vocab], axis=-1)
onehot = tf.cast(matches, tf.float32)

with tf.Session() as sess:
    out = sess.run(onehot, feed_dict={labels: ['c', 'a']})
    print(out)  # prints [[0. 0. 1.]
                #         [1. 0. 0.]]
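The answer's graph compares each input string against every entry of vocab, stacks the boolean results along the last axis, and casts them to floats. Assuming NumPy is available, the same computation can be sketched with broadcasting. The extra label 'z' is a hypothetical value added to show that a string missing from vocab produces an all-zero row rather than an error, which is also what the TensorFlow version does.

```python
import numpy as np

vocab = np.array(['a', 'b', 'c'])
labels = np.array(['c', 'a', 'z'])  # 'z' is deliberately not in vocab

# Compare every label with every vocab entry; result shape is
# (len(labels), len(vocab)), like stacking tf.equal results on axis=-1.
matches = labels[:, None] == vocab[None, :]
onehot = matches.astype(np.float32)
print(onehot)
# [[0. 0. 1.]
#  [1. 0. 0.]
#  [0. 0. 0.]]
```

If unknown labels should be treated as errors instead, check that every row sums to 1 before training.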
https://en.xdnf.cn/q/71718.html
