Using a custom threshold value with tf.contrib.learn.DNNClassifier?

2024/10/13 5:20:23

I'm working on a binary classification problem and I'm using the tf.contrib.learn.DNNClassifier class within TensorFlow. When invoking this estimator for only 2 classes, it uses a threshold value of 0.5 as the cutoff between the 2 classes. I'd like to know if there's a way to use a custom threshold value since this might improve the model's accuracy.
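To make the default behaviour concrete, here is a minimal pure-Python sketch. It assumes that with two classes the estimator predicts the class with the larger probability (an argmax over `[p0, p1]`, where `p0 = 1 - p1`). Choosing the larger of the two is equivalent to checking `p1` against a fixed cutoff of 0.5. The helper names and sample rows are hypothetical.

```python
# Sketch (assumption): the 2-class prediction is the argmax of [p0, p1].
# Because p0 = 1 - p1, "p1 is the larger one" is the same test as "p1 > 0.5".

def argmax_label(probs):
    """Return the index of the largest probability (first index wins on ties)."""
    return max(range(len(probs)), key=lambda i: probs[i])


def default_cutoff_label(p1):
    """Label 1 only when the class-1 probability is strictly above 0.5."""
    return 1 if p1 > 0.5 else 0


# Hypothetical [p_class0, p_class1] rows (no exact ties).
rows = [[0.9, 0.1], [0.4, 0.6], [0.3, 0.7], [0.8, 0.2]]

for row in rows:
    # Both rules agree on every row.
    assert argmax_label(row) == default_cutoff_label(row[1])

print([argmax_label(r) for r in rows])  # [0, 1, 1, 0]
```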

I've searched all around the web and apparently there isn't a way to do this.

Any help will be greatly appreciated, thank you.

Answer

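The tf.contrib.learn.DNNClassifier class has a method called predict_proba, which returns the probability of each class for the given inputs. You can then apply your own cutoff to the probability of the positive class. For example, if prob is a [batch_size, 2] array of probabilities and thres is your custom threshold, tf.cast(prob[:, 1] > thres, tf.int32) predicts class 1 whenever that probability exceeds thres. Be careful with a rounding shortcut such as tf.round(prob + thres) on the positive-class probability. It predicts class 1 when the probability exceeds 0.5 - thres, so it does not place the cutoff at thres.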
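A minimal sketch of post-hoc thresholding in plain Python follows. It assumes `probas` stands in for the output of predict_proba as a list of `[p_class0, p_class1]` rows. The sample data and the helper names `apply_threshold` and `round_with_offset` are hypothetical. The second helper shows why adding an offset and rounding shifts the cutoff to `0.5 - offset`. Python's `round()`, like `tf.round`, rounds halves to even.

```python
# Hypothetical stand-in for predict_proba output: [p_class0, p_class1] rows.
probas = [[0.9, 0.1], [0.65, 0.35], [0.45, 0.55], [0.2, 0.8]]


def apply_threshold(probas, threshold=0.5):
    """Predict class 1 when the class-1 probability is strictly above `threshold`."""
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be in [0, 1]")
    return [1 if row[1] > threshold else 0 for row in probas]


def round_with_offset(probas, offset):
    """Mimic round(p1 + offset): this behaves like a cutoff at 0.5 - offset."""
    return [int(round(row[1] + offset)) for row in probas]


print(apply_threshold(probas))       # default cutoff 0.5 -> [0, 0, 1, 1]
print(apply_threshold(probas, 0.3))  # lower cutoff       -> [0, 1, 1, 1]
print(apply_threshold(probas, 0.7))  # higher cutoff      -> [0, 0, 0, 1]

# Rounding with an offset of 0.1 matches a cutoff of 0.4, not 0.1.
print(round_with_offset(probas, 0.1))  # [0, 0, 1, 1]
print(apply_threshold(probas, 0.1))    # [0, 1, 1, 1]
```

Comparing against the threshold directly keeps the meaning of `thres` obvious. A lower threshold labels more examples as class 1, and a higher one labels fewer.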

https://en.xdnf.cn/q/118119.html
