How to get feature names of shap_values from TreeExplainer?

2024/9/8 10:30:49

I am working through a SHAP tutorial and trying to get the SHAP values for each person in a dataset:

from sklearn.model_selection import train_test_split
import xgboost
import shap
import numpy as np
import pandas as pd
import matplotlib.pylab as pl

X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)

# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
d_train = xgboost.DMatrix(X_train, label=y_train)
d_test = xgboost.DMatrix(X_test, label=y_test)

params = {
    "eta": 0.01,
    "objective": "binary:logistic",
    "subsample": 0.5,
    "base_score": np.mean(y_train),
    "eval_metric": "logloss"
}
#model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=100, early_stopping_rounds=20)

xg_clf = xgboost.XGBClassifier()
xg_clf.fit(X_train, y_train)

explainer = shap.TreeExplainer(xg_clf, X_train)
#shap_values = explainer(X)
shap_values = explainer.shap_values(X)

Inspecting it in the Python 3 interpreter, shap_values is a massive array covering 32,561 persons, each with a SHAP value for 12 features.

For example, the first individual has the following SHAP values:

>>> shap_values[0]
array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,-0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,-0.26587735,  0.02700199])

However, which value applies to which feature is a complete mystery to me.

The documentation says:

For models with a single output this returns a matrix of SHAP values (# samples x # features). Each row sums to the difference between the model output for that sample and the expected value of the model output (which is stored in the expected_value attribute of the explainer when it is constant). For models with vector outputs this returns a list of such matrices, one for each output.

When I look at the explainer that produced shap_values, I see that I can get the feature names:

explainer.data_feature_names
['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

but I cannot see how to get feature names within shap_values at the Python interpreter, if they're even there:

>>> shap_values.
shap_values.all(           shap_values.compress(      shap_values.dump(          shap_values.max(           shap_values.ravel(         shap_values.sort(          shap_values.tostring(
shap_values.any(           shap_values.conj(          shap_values.dumps(         shap_values.mean(          shap_values.real           shap_values.squeeze(       shap_values.trace(
shap_values.argmax(        shap_values.conjugate(     shap_values.fill(          shap_values.min(           shap_values.repeat(        shap_values.std(           shap_values.transpose(
shap_values.argmin(        shap_values.copy(          shap_values.flags          shap_values.nbytes         shap_values.reshape(       shap_values.strides        shap_values.var(
shap_values.argpartition(  shap_values.ctypes         shap_values.flat           shap_values.ndim           shap_values.resize(        shap_values.sum(           shap_values.view(
shap_values.argsort(       shap_values.cumprod(       shap_values.flatten(       shap_values.newbyteorder(  shap_values.round(         shap_values.swapaxes(      
shap_values.astype(        shap_values.cumsum(        shap_values.getfield(      shap_values.nonzero(       shap_values.searchsorted(  shap_values.T              
shap_values.base           shap_values.data           shap_values.imag           shap_values.partition(     shap_values.setfield(      shap_values.take(          
shap_values.byteswap(      shap_values.diagonal(      shap_values.item(          shap_values.prod(          shap_values.setflags(      shap_values.tobytes(       
shap_values.choose(        shap_values.dot(           shap_values.itemset(       shap_values.ptp(           shap_values.shape          shap_values.tofile(        
shap_values.clip(          shap_values.dtype          shap_values.itemsize       shap_values.put(           shap_values.size           shap_values.tolist(    

My primary question: How can I figure out which feature in

['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

applies to which number in each row of shap_values?

>>> shap_values[0]
array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,-0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,-0.26587735,  0.02700199])

I would assume that the features are in the same order, but I have no evidence for that.

My secondary question: how can I find the feature names in shap_values?

Answer

The features are indeed in the same order, as you assume; see the GitHub issues "how to extract the most important feature names?" and "how to get feature names from explainer".

To find the name of a feature, index the list of names with the same position that its value has in the array of SHAP values.

For example:

shap_values = np.array([0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443, -0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322, -0.26587735,  0.02700199])
features_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

features_names[shap_values.argmin()]  # the index 8 -> Capital Gain
features_names[shap_values.argmax()]  # the index 0 -> Age
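
Building on the same idea, here is a minimal sketch that labels every value in a row at once instead of looking up one index at a time. It assumes the column order matches explainer.data_feature_names, as stated above; the names row, named, and ranked are made up for this illustration, and the numbers are the first individual's SHAP values from the question.

```python
import numpy as np

# SHAP values for the first individual, copied from the question.
row = np.array([0.76437867, -0.11881508, 0.57451954, -0.41974955,
                -0.20982443, -0.38079952, -0.00986504, 0.32272505,
                -3.04392116, 0.00411322, -0.26587735, 0.02700199])

# Feature names in column order, as reported by explainer.data_feature_names.
feature_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status',
                 'Occupation', 'Relationship', 'Race', 'Sex',
                 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

# Pair each name with its value: position i in the row belongs to feature i.
named = dict(zip(feature_names, row.tolist()))

# Rank features by the size of their contribution, largest first.
order = np.argsort(-np.abs(row))
ranked = [(feature_names[i], float(row[i])) for i in order]

# Show the three features that moved this prediction the most.
for name, value in ranked[:3]:
    print(f"{name}: {value:+.4f}")
```

For the full matrix, the same pairing can probably be done in one step with pd.DataFrame(shap_values, columns=X.columns), since X in the question is a pandas DataFrame whose columns carry the feature names.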
