Using read_batch_record_features with an Estimator

2024/10/14 21:17:51

(I'm using TensorFlow 1.0 and Python 2.7.)

I'm having trouble getting an Estimator to work with queues. If I use the deprecated SKCompat interface with custom data files and a given batch size, the model trains properly. With the new interface, I use an input_fn that batches features out of TFRecord files (equivalent to my custom data files). The script runs without errors, but the loss value stops changing after 200 or 300 steps. It looks as if the model keeps looping over a small input batch, which would explain why the loss converges so fast.
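One way to test the "looping on a small batch" hypothesis is to log an identifier for every record the input pipeline yields and count how many distinct ones actually appear. The sketch below is plain Python, not TensorFlow. It assumes each record carries some unique key (the hypothetical record ids below), which the question's TFRecord schema ('rawData', 'label') does not include; `coverage_report` is likewise a hypothetical helper.

```python
def coverage_report(batches, dataset_size):
    """Summarize how many distinct records a stream of batches covered.

    `batches` is any iterable of batches of record ids. The ids are a
    hypothetical addition: the question's records only hold 'rawData'
    and 'label'.
    """
    seen = set()
    total = 0
    for batch in batches:
        seen.update(batch)
        total += len(batch)
    return {
        "records_yielded": total,
        "distinct_records": len(seen),
        # float() keeps the division correct on Python 2.7 as well
        "fraction_of_dataset": len(seen) / float(dataset_size),
    }


# A healthy pipeline over 1,000 records: every record shows up once.
healthy = [list(range(i * 100, (i + 1) * 100)) for i in range(10)]
print(coverage_report(healthy, 1000))

# A "stuck" pipeline that keeps returning the same 100 records.
stuck = [list(range(100))] * 10
print(coverage_report(stuck, 1000))
```

If the distinct count stays near one batch's worth while the step count grows, the pipeline really is repeating itself. If it grows toward the dataset size, the plateau has some other cause.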

I have a 'run.py' script that looks like the following:

import tensorflow as tf
from tensorflow.contrib import learn, metrics
# [...]

evalMetrics = {'accuracy': learn.MetricSpec(metric_fn=metrics.streaming_accuracy)}
runConfig = learn.RunConfig(save_summary_steps=10)
estimator = learn.Estimator(model_fn=myModel,
                            params=myParams,
                            model_dir='/tmp/myDir',
                            config=runConfig)

session = tf.Session(graph=tf.get_default_graph())
with session.as_default():
    tf.global_variables_initializer()  # builds the init op; it is not run here
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=session, coord=coordinator)
    estimator.fit(input_fn=lambda: inputToModel(trainingFileList), steps=10000)
    estimator.evaluate(input_fn=lambda: inputToModel(evalFileList),
                       steps=10000,
                       metrics=evalMetrics)
    coordinator.request_stop()
    coordinator.join(threads)
session.close()
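The Coordinator and start_queue_runners lines in run.py follow a standard pattern. Background threads fill a queue, the main loop dequeues from it, and a shared coordinator tells every thread when to stop. The plain-Python analogue below sketches that pattern with `threading` and `queue`, not TensorFlow. `Coordinator`, `producer` and `run_pipeline` are hypothetical names, and TensorFlow's real classes do more (for example, exception reporting).

There is one caveat, to the best of my understanding of TF 1.x. contrib.learn's `Estimator.fit` builds its own graph, calls the input_fn inside it, and starts that graph's queue runners in its own monitored session. Threads started manually on a separate session of the default graph, as in run.py, would therefore not be the ones feeding the Estimator.

```python
import queue
import threading


class Coordinator(object):
    """Toy stand-in for tf.train.Coordinator: a shared stop flag plus join()."""

    def __init__(self):
        self._stop = threading.Event()

    def should_stop(self):
        return self._stop.is_set()

    def request_stop(self):
        self._stop.set()

    def join(self, threads, timeout=5.0):
        for t in threads:
            t.join(timeout)


def producer(coord, q, records):
    """Keep enqueueing records (cycling forever) until a stop is requested."""
    i = 0
    while not coord.should_stop():
        try:
            q.put(records[i % len(records)], timeout=0.05)
            i += 1
        except queue.Full:
            continue  # queue full: loop back and re-check the stop flag


def run_pipeline(records, num_threads=4, steps=50):
    coord = Coordinator()
    q = queue.Queue(maxsize=16)
    threads = [threading.Thread(target=producer, args=(coord, q, records))
               for _ in range(num_threads)]
    for t in threads:
        t.daemon = True
        t.start()
    consumed = [q.get() for _ in range(steps)]  # stands in for the training loop
    coord.request_stop()   # signal every producer thread to finish
    coord.join(threads)    # wait for them to exit
    return consumed, threads


consumed, threads = run_pipeline(["a", "b", "c"], num_threads=4, steps=50)
print(len(consumed), any(t.is_alive() for t in threads))
```

The timeout on `put` matters. Without it, a producer blocked on a full queue would never notice the stop request and `join` would hang.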

My inputToModel function looks like this:

import tensorflow as tf

def inputToModel(fileList):
    features = {'rawData': tf.FixedLenFeature([100], tf.float32),
                'label': tf.FixedLenFeature([], tf.int64)}
    tensorDict = tf.contrib.learn.read_batch_record_features(
        fileList,
        batch_size=100,
        features=features,
        randomize_input=True,
        reader_num_threads=4,
        num_epochs=1,
        name='inputPipeline')
    # creates (but does not run) the init op for local variables,
    # such as the epoch counter that num_epochs relies on
    tf.local_variables_initializer()
    data = tensorDict['rawData']
    labelTensor = tensorDict['label']
    # 100 floats per record -> 10x10 single-channel "image"
    inputTensor = tf.reshape(data, [-1, 10, 10, 1])
    return inputTensor, labelTensor
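The num_epochs=1, batch_size=100 and randomize_input=True arguments in inputToModel make the input finite. Each record is read once, in shuffled order, and grouped into batches of 100. The plain-Python generator below sketches that behaviour; `batched_epochs` is a hypothetical name, not a TensorFlow function. Whether TensorFlow emits the final partial batch is not asserted here.

```python
import random


def batched_epochs(records, batch_size, num_epochs=1, shuffle=True, seed=None):
    """Yield batches the way a finite, shuffled input pipeline conceptually does.

    After `num_epochs` passes over `records` the generator simply ends; in a
    TF 1.x queue pipeline the analogous end-of-input signal is an
    OutOfRangeError. This sketch keeps the final, smaller batch.
    """
    rng = random.Random(seed)
    buffer = []
    for _ in range(num_epochs):
        order = list(records)
        if shuffle:
            rng.shuffle(order)
        for record in order:
            buffer.append(record)
            if len(buffer) == batch_size:
                yield buffer
                buffer = []  # start a fresh list so the yielded one is untouched
    if buffer:
        yield buffer  # leftover records form one last, smaller batch


# Hypothetical dataset of 2,500 records, one epoch, batches of 100.
batches = list(batched_epochs(range(2500), batch_size=100, num_epochs=1, seed=0))
print(len(batches))
```

With that made-up record count only 25 batches exist. As I understand TF 1.x, `fit(steps=10000)` would then run out of input after about 25 steps rather than train for 10,000.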

Any help or suggestions are welcome!

Answer

Try using tf.global_variables_initializer().run() instead of just tf.global_variables_initializer().
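The suggestion rests on TF 1.x's deferred execution. Calling tf.global_variables_initializer() only adds an initialization op to the graph, and nothing is initialized until that op is run in a session (`.run()` uses the default session). The toy model below mimics that build/run split in plain Python. `Graph`, `Variable` and `Session` here are hypothetical classes, not TensorFlow's.

```python
class Variable(object):
    def __init__(self, initial_value):
        self.initial_value = initial_value
        self.value = None  # uninitialized until an init op is executed


class Graph(object):
    def __init__(self):
        self.variables = []

    def variable(self, initial_value):
        v = Variable(initial_value)
        self.variables.append(v)
        return v

    def variables_initializer(self):
        """Build, but do not execute, an op that initializes all variables."""
        variables = list(self.variables)

        def init_op():
            for v in variables:
                v.value = v.initial_value
        return init_op


class Session(object):
    def __init__(self, graph):
        self.graph = graph

    def run(self, op):
        return op()  # only here does the op actually execute


g = Graph()
w = g.variable(3)
init = g.variables_initializer()
print(w.value)        # still None: the op was only built
Session(g).run(init)
print(w.value)        # now initialized
```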

I want to do something similar, but I don't know how to use the Estimator API with multithreading. There is also an Experiment class (tf.contrib.learn.Experiment), which might be useful.

Delete the lines session = tf.Session(graph=tf.get_default_graph()) and session.close(), and try:

with tf.Session() as sess:
    tf.global_variables_initializer().run()
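`with tf.Session() as sess:` is generally preferred over a bare tf.Session() plus session.close(). Inside the block the session is the default session, so `.run()` calls work, and it is closed on exit even when an exception is raised. Below is a plain-Python sketch of that context-manager behaviour. `ToySession` and the module-level default-session stack are hypothetical stand-ins for TensorFlow's internals.

```python
_default_sessions = []  # hypothetical stand-in for TF's default-session stack


class ToySession(object):
    def __init__(self):
        self.closed = False

    def __enter__(self):
        _default_sessions.append(self)  # becomes the default inside the block
        return self

    def __exit__(self, exc_type, exc, tb):
        _default_sessions.pop()
        self.close()   # runs on normal exit and when an exception escapes
        return False   # do not swallow the exception

    def close(self):
        self.closed = True


def get_default_session():
    return _default_sessions[-1] if _default_sessions else None


with ToySession() as s:
    print(get_default_session() is s)
print(s.closed)
```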
https://en.xdnf.cn/q/117912.html
