Visualize strengths and weaknesses of a sample from pre-trained model

2024/10/12 0:26:05

Let's say I'm trying to predict apartment prices. I have a lot of labeled data, where for each apartment I have features that could affect the price, such as:

  • city
  • street
  • floor
  • year built
  • socioeconomic status
  • square feet
  • etc.

I train a model, let's say XGBoost. Now I want to predict the price of a new apartment. Is there a good way to show what is "good" about this apartment, what is "bad", and by how much (scaled 0-1)?

For example: the floor number is a "strong" feature (i.e., in this area this floor number is desirable and therefore positively affects the price of the apartment), but the socioeconomic status is a "weak" feature (i.e., the socioeconomic status is low and therefore negatively affects the price of the apartment).

What I want is to illustrate, more or less, why my model decided on this price, and to let the user get a feel for the apartment's value through those indicators.

I thought of doing an exhaustive search over each feature, but I'm afraid that would take too much time.

Is there a more brilliant way of doing this?

Any help would be much appreciated...

Answer

Good news for you: there is.

A package called "SHAP" (SHapley Additive exPlanations) was recently released for exactly that purpose. It is available on GitHub.

It supports explaining complicated models that are hard to interpret intuitively, such as boosted trees (and XGBoost in particular!).

It can show you a more trustworthy feature importance than the "gain", "weight", and "cover" importances that XGBoost supplies, because those are not consistent: they can rank the same features differently.

You can read all about why SHAP is better for feature evaluation here.

It is hard to give you code that will work as-is for your data, but the documentation is good and you should be able to write something that suits you.

Here are the guidelines for building your first plot:

import shap
import xgboost as xgb

# Assume X_train and y_train are the features and labels of your training samples,
# and that feature_names, weights_trn, params0 and watchlist are defined elsewhere
dtrain = xgb.DMatrix(X_train, label=y_train, feature_names=feature_names, weight=weights_trn)

# Train your xgboost model
bst = xgb.train(params0, dtrain, num_boost_round=2500, evals=watchlist, early_stopping_rounds=200)

# "explainer" object of shap
explainer = shap.TreeExplainer(bst)

# The samples you explain: here X_test, but you can "explain" whatever samples you want
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
shap.summary_plot(shap_values, X_test, plot_type="bar")

To plot "why a certain sample got its score", you can use the built-in SHAP function for it (its default interactive version only renders in a Jupyter Notebook). A perfect example is here.

Alternatively, you can write your own plotting function with matplotlib, as I did, though that takes some effort.
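A hand-rolled plot does not need to be elaborate. Below is one possible sketch, not the author's function: `plot_sample_contributions` is a hypothetical helper that takes one sample's SHAP values (for example one row of `explainer.shap_values(...)`) and draws a horizontal bar per feature, green for features that raise the output and red for features that lower it, sorted so the largest effect ends up on top. The feature names and numbers in the usage lines are made up.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_sample_contributions(feature_names, contributions, base_value,
                              title="Why this sample got its score"):
    """Bar chart of one sample's SHAP values (hypothetical helper, not part of SHAP)."""
    contributions = np.asarray(contributions, dtype=float)
    # Smallest |effect| first, so barh draws the largest effect at the top
    order = np.argsort(np.abs(contributions))
    names = [feature_names[k] for k in order]
    values = contributions[order]
    colors = ["tab:green" if v > 0 else "tab:red" for v in values]

    fig, ax = plt.subplots(figsize=(6, 0.4 * len(values) + 1.5))
    ax.barh(names, values, color=colors)
    ax.axvline(0, color="black", linewidth=0.8)
    ax.set_xlabel("SHAP value (contribution to the model output)")
    # Additivity: base value + sum of SHAP values = model output for this sample
    output = base_value + contributions.sum()
    ax.set_title(f"{title}\nbase value = {base_value:.3g}, model output = {output:.3g}")
    fig.tight_layout()
    return fig, ax


# Hypothetical SHAP values for a single apartment (made-up numbers)
features = ["floor", "socioeconomic_status", "square_feet", "year_built"]
sample_shap = [12_000.0, -35_000.0, 48_000.0, -4_000.0]
fig, ax = plot_sample_contributions(features, sample_shap, base_value=450_000.0)
fig.savefig("sample_contributions.png")
```

Sorting by absolute value keeps the most influential features visible even when a model has many inputs; you could also cut the list to the top few features.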

Here is an example of a plot I made from the SHAP values (the feature names are confidential, so they are all erased):

[Image: single-sample SHAP contribution plot, feature names erased]

You can see a 97% predicted probability of label=1, and how much each feature added to or subtracted from the model's log-odds output for that specific sample.
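For a binary XGBoost classifier, `TreeExplainer` by default works in the model's raw margin, i.e. log-odds, so the per-feature numbers add up in log-odds space rather than in probability space. A minimal worked example with hypothetical numbers (the real features are confidential, so `feature_a` to `feature_d` and all values are made up) shows how they combine into a probability of about 97%:

```python
import math

# Hypothetical values for one sample of a binary classifier (made up for illustration).
# base_value plays the role of explainer.expected_value, in log-odds units.
base_value = -0.5
sample_shap = {"feature_a": 2.0, "feature_b": 1.5, "feature_c": 0.8, "feature_d": -0.324}

# Additivity: base value + sum of SHAP values = the model's raw output (log-odds)
log_odds = base_value + sum(sample_shap.values())

# The logistic (sigmoid) function maps log-odds to the predicted probability of label=1
probability = 1.0 / (1.0 + math.exp(-log_odds))
print(f"log-odds = {log_odds:.3f}, P(label=1) = {probability:.2f}")
```

Because the sigmoid is non-linear, a feature's effect on the probability depends on where the sample already sits, which is why the contributions are reported in log-odds.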

