How to get feature names of shap_values from TreeExplainer?

2024/9/8 10:30:49

I am working through a SHAP tutorial and trying to get the SHAP values for each person in a dataset:

from sklearn.model_selection import train_test_split
import xgboost
import shap
import numpy as np
import pandas as pd
import matplotlib.pylab as pl

# Assumption: the adult census dataset bundled with shap (32,561 rows, 12 features)
X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)

# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
d_train = xgboost.DMatrix(X_train, label=y_train)
d_test = xgboost.DMatrix(X_test, label=y_test)

params = {"eta": 0.01, "objective": "binary:logistic", "subsample": 0.5,
          "base_score": np.mean(y_train), "eval_metric": "logloss"}
#model = xgboost.train(params, d_train, 5000, evals = [(d_test, "test")], verbose_eval=100, early_stopping_rounds=20)

# fit the scikit-learn style classifier that the explainer wraps
xg_clf = xgboost.XGBClassifier()
xg_clf.fit(X_train, y_train)

explainer = shap.TreeExplainer(xg_clf, X_train)
#shap_values = explainer(X)
shap_values = explainer.shap_values(X)

In the Python 3 interpreter, shap_values is a large array covering 32,561 persons, each with a SHAP value for 12 features.

For example, the first individual has the following SHAP values:

>>> shap_values[0]
array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,-0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,-0.26587735,  0.02700199])

However, which value applies to which feature is a complete mystery to me.

The documentation says:

For models with a single output this returns a matrix of SHAP values (# samples x # features). Each row sums to the difference between the model output for that sample and the expected value of the model output (which is stored in the expected_value attribute of the explainer when it is constant). For models with vector outputs this returns a list of such matrices, one for each output.

When I inspect the explainer that produced shap_values, I can see the feature names:

['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

but I cannot see how to get feature names within shap_values at the Python interpreter, if they're even there:

>>> shap_values.
shap_values.all(           shap_values.compress(      shap_values.dump(          shap_values.max(           shap_values.ravel(         shap_values.sort(          shap_values.tostring(
shap_values.any(           shap_values.conj(          shap_values.dumps(         shap_values.mean(          shap_values.real           shap_values.squeeze(       shap_values.trace(
shap_values.argmax(        shap_values.conjugate(     shap_values.fill(          shap_values.min(           shap_values.repeat(        shap_values.std(           shap_values.transpose(
shap_values.argmin(        shap_values.copy(          shap_values.flags          shap_values.nbytes         shap_values.reshape(       shap_values.strides        shap_values.var(
shap_values.argpartition(  shap_values.ctypes         shap_values.flat           shap_values.ndim           shap_values.resize(        shap_values.sum(           shap_values.view(
shap_values.argsort(       shap_values.cumprod(       shap_values.flatten(       shap_values.newbyteorder(  shap_values.round(         shap_values.swapaxes(      
shap_values.astype(        shap_values.cumsum(        shap_values.getfield(      shap_values.nonzero(       shap_values.searchsorted(  shap_values.T              
shap_values.base            shap_values.imag           shap_values.partition(     shap_values.setfield(      shap_values.take(          
shap_values.byteswap(      shap_values.diagonal(      shap_values.item(          shap_values.setflags(      shap_values.tobytes(       
shap_values.choose(           shap_values.itemset(       shap_values.ptp(           shap_values.shape          shap_values.tofile(        
shap_values.clip(          shap_values.dtype          shap_values.itemsize       shap_values.put(           shap_values.size           shap_values.tolist(    

My primary question: How can I figure out which feature in

['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

applies to which number in each row of shap_values?

>>> shap_values[0]
array([ 0.76437867, -0.11881508,  0.57451954, -0.41974955, -0.20982443,-0.38079952, -0.00986504,  0.32272505, -3.04392116,  0.00411322,-0.26587735,  0.02700199])

I would assume that the features are in the same order, but I have no evidence for that.

My secondary question: How can I find the feature names in shap_values?


The features are indeed in the same order, as you assume; see the GitHub issues "How to extract the most important feature names?" and "How to get feature names from explainer".

To find the feature name for a given value, look up the element at the same index in the list of names.

For example:

import numpy as np

shap_values = np.array([0.76437867, -0.11881508, 0.57451954, -0.41974955, -0.20982443, -0.38079952, -0.00986504, 0.32272505, -3.04392116, 0.00411322, -0.26587735, 0.02700199])
features_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

features_names[shap_values.argmin()]  # index 8 -> 'Capital Gain'
features_names[shap_values.argmax()]  # index 0 -> 'Age'
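
Because the mapping is purely positional, you can label a whole row at once rather than one index at a time. The following is a minimal sketch, not part of shap itself: the helper names label_row and rank_by_magnitude are made up for illustration, and the numbers are the first individual's SHAP values copied from the question.

```python
import numpy as np

# Feature names in column order, as listed in the question.
feature_names = ['Age', 'Workclass', 'Education-Num', 'Marital Status',
                 'Occupation', 'Relationship', 'Race', 'Sex',
                 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

# SHAP values of the first individual, copied from the question.
row = np.array([0.76437867, -0.11881508, 0.57451954, -0.41974955,
                -0.20982443, -0.38079952, -0.00986504, 0.32272505,
                -3.04392116, 0.00411322, -0.26587735, 0.02700199])


def label_row(values, names):
    """Pair each SHAP value with the feature name at the same position.

    Hypothetical helper for illustration; it only relies on the column order.
    """
    if len(values) != len(names):
        raise ValueError("expected one SHAP value per feature name")
    return dict(zip(names, (float(v) for v in values)))


def rank_by_magnitude(values, names):
    """Return (name, value) pairs sorted by absolute SHAP value, largest first."""
    order = np.argsort(-np.abs(values))
    return [(names[i], float(values[i])) for i in order]


labeled = label_row(row, feature_names)
ranked = rank_by_magnitude(row, feature_names)

print(labeled['Capital Gain'])
print(ranked[:3])
```

For the full matrix, and assuming X is a pandas DataFrame, list(X.columns) should give the names in the matching order, so label_row(shap_values[i], list(X.columns)) would label person i. As far as I know, recent shap versions also return an Explanation object when the explainer is called directly, as in the commented-out explainer(X) line in the question, and that object carries a feature_names attribute next to its values.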

