I'm new to TensorFlow and would like to read a comma-separated values (CSV) file containing two columns: column 1 is the index and column 2 is a label string. I have the following code, which reads the CSV file line by line, and I can get the data correctly using print statements. However, I would like to convert the string labels to one-hot vectors and do not know how to do it in TensorFlow. The final goal is to use the tf.train.batch() function so I can get batches of one-hot label vectors to train a neural network.
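For reference, the target conversion (label string to class index to one-hot vector) can be sketched in plain Python, outside TensorFlow. This assumes the ten class names used in the if/elif chain of the code below, in the same order. `CLASS_NAMES`, `CLASS_TO_INDEX`, `label_to_one_hot` and `labels_to_one_hot_batch` are hypothetical names chosen for illustration. A list lookup replaces the if/elif chain, and a batch is simply a list of such vectors.

```python
# Hypothetical helpers: class names in the same order as the if/elif chain,
# so a label's position in this list is its one-hot index.
CLASS_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
CLASS_TO_INDEX = {name: i for i, name in enumerate(CLASS_NAMES)}


def label_to_one_hot(label):
    """Return a list of len(CLASS_NAMES) ints with a single 1 at the label's index."""
    vector = [0] * len(CLASS_NAMES)
    vector[CLASS_TO_INDEX[label]] = 1  # raises KeyError for an unknown label
    return vector


def labels_to_one_hot_batch(labels):
    """Convert a batch of label strings into a list of one-hot rows."""
    return [label_to_one_hot(label) for label in labels]


print(label_to_one_hot('bird'))  # prints [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
```

Unlike the if/elif chain, which silently falls back to index 0, this sketch raises `KeyError` for a label that is not one of the ten names.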
As you can see in the code below, I can create a one-hot vector for each of the label entries manually within a TensorFlow session. But how do I use the tf.train.batch() function? If I move the line
```python
label_batch = tf.train.batch([col2], batch_size=5)
```
into the TensorFlow session block (replacing col2 with label_one_hot), the program hangs and does nothing. I also tried to move the one-hot conversion outside the TensorFlow session, but I could not get it to work correctly. What is the correct way to do this? Please help.
```python
import tensorflow as tf  # TensorFlow 1.x queue-based input API

# LABEL_FILE is the path to the two-column CSV file (defined elsewhere).
label_files = []
label_files.append(LABEL_FILE)
print "label_files: ", label_files

filename_queue = tf.train.string_input_producer(label_files)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
print "key:", key, ", value:", value

# Column 1 is the index, column 2 the label string.
record_defaults = [['default_id'], ['default_label']]
col1, col2 = tf.decode_csv(value, record_defaults=record_defaults)
num_lines = sum(1 for line in open(LABEL_FILE))

label_batch = tf.train.batch([col2], batch_size=5)

with tf.Session() as sess:
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coordinator)
    for i in range(100):
        column1, column2 = sess.run([col1, col2])

        # Map the label string to a class index by hand.
        index = 0
        if column2 == 'airplane':
            index = 0
        elif column2 == 'automobile':
            index = 1
        elif column2 == 'bird':
            index = 2
        elif column2 == 'cat':
            index = 3
        elif column2 == 'deer':
            index = 4
        elif column2 == 'dog':
            index = 5
        elif column2 == 'frog':
            index = 6
        elif column2 == 'horse':
            index = 7
        elif column2 == 'ship':
            index = 8
        elif column2 == 'truck':
            index = 9

        label_one_hot = tf.one_hot([index], 10)  # depth=10 for 10 categories
        print "column1:", column1, ", column2:", column2
        # print "onehot label:", sess.run([label_one_hot])
        print sess.run(label_batch)

    coordinator.request_stop()
    coordinator.join(threads)
```