dalex - more explanations: residuals, shap, lime

imports

In [1]:
import dalex as dx 

import numpy as np
import pandas as pd

from lightgbm import LGBMRegressor
from sklearn.svm import SVR
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

import plotly
plotly.offline.init_notebook_mode()

import warnings
warnings.filterwarnings('ignore')
In [2]:
dx.__version__
Out[2]:
'1.7.0'

prepare data

Transform the skewed target variable (y) for better model fit.

In [3]:
data = dx.datasets.load_fifa()
X = data.drop(["nationality", "overall", "potential", "value_eur", "wage_eur"], axis = 1)
y = data['value_eur']

ylog = np.log(y)
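
Why the log helps can be seen on a synthetic right-skewed sample. The lognormal values and the skewness helper below are made up for illustration; they are not the FIFA data.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic right-skewed "player values" (lognormal), a stand-in for value_eur.
y_demo = rng.lognormal(mean=14.0, sigma=1.0, size=5000)
ylog_demo = np.log(y_demo)

def skewness(a):
    """Sample skewness: the third standardized moment."""
    z = (a - a.mean()) / a.std()
    return float(np.mean(z ** 3))

# The raw target is strongly right-skewed; its log is close to symmetric.
print(skewness(y_demo), skewness(ylog_demo))

# np.exp undoes the transform, so predictions can be mapped back later.
print(np.allclose(np.exp(ylog_demo), y_demo))
```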

create models

Use Pipeline to scale the data.

In [4]:
model_svm = Pipeline(steps=[('scale', StandardScaler()),
                            ('model', SVR(C=10, epsilon=0.2, tol=1e-4))])
model_svm.fit(X, ylog)
Out[4]:
Pipeline(steps=[('scale', StandardScaler()),
                ('model', SVR(C=10, epsilon=0.2, tol=0.0001))])
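
SVR with its default RBF kernel works on distances between rows, so a feature on a large scale would dominate unscaled data. A numpy sketch of the per-column z-score that StandardScaler computes, on made-up numbers:

```python
import numpy as np

# Made-up feature matrix with columns on different scales (think age, height_cm).
X_demo = np.array([[20.0, 170.0],
                   [25.0, 180.0],
                   [30.0, 190.0]])

# Per-column z-score using the population standard deviation (ddof=0),
# which is the convention StandardScaler follows.
mean = X_demo.mean(axis=0)
std = X_demo.std(axis=0)
X_scaled = (X_demo - mean) / std

print(X_scaled)
```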
In [5]:
model_gbm = LGBMRegressor(n_estimators=200, max_depth=10, learning_rate=0.15, random_state=0, verbose=-1)
model_gbm.fit(X, ylog)
Out[5]:
LGBMRegressor(learning_rate=0.15, max_depth=10, n_estimators=200,
              random_state=0, verbose=-1)

predict_function

Because we transformed the target, we replace the default predict_function with one that returns predictions on the original scale of y.

In [6]:
def predict_function(model, data):
    return np.exp(model.predict(data))
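
A self-contained check of the back-transform, using a hypothetical stand-in model (LogScaleModel is not part of dalex or the notebook) whose predictions play the role of log-values:

```python
import numpy as np

class LogScaleModel:
    """Hypothetical stand-in for a model fitted on log(y)."""

    def predict(self, data):
        # Pretend the log-scale prediction is simply the row sum.
        return np.asarray(data).sum(axis=1)

def predict_function(model, data):
    # Same shape as the notebook's function: undo the log with exp.
    return np.exp(model.predict(data))

demo = np.array([[0.0, 0.0],
                 [1.0, 0.0]])
print(predict_function(LogScaleModel(), demo))
```

The first row has a log-prediction of 0, so it maps back to 1 on the original scale.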

create an explainer for the model

Explainer prints useful information, especially for resolving potential errors.

In [7]:
exp_svm = dx.Explainer(model_svm, data=X, y=y, predict_function=predict_function, label='svm')
exp_gbm = dx.Explainer(model_gbm, data=X, y=y, predict_function=predict_function, label='gbm')
Preparation of a new explainer is initiated

  -> data              : 5000 rows 37 cols
  -> target variable   : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray.
  -> target variable   : 5000 values
  -> model_class       : sklearn.svm._classes.SVR (default)
  -> label             : svm
  -> predict function  : <function predict_function at 0x2a5eb28e0> will be used
  -> predict function  : Accepts pandas.DataFrame and numpy.ndarray.
  -> predicted values  : min = 2.2e+05, mean = 7.25e+06, max = 8.69e+07
  -> model type        : regression will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         : min = -1.53e+07, mean = 2.19e+05, max = 1.86e+07
  -> model_info        : package sklearn

A new explainer has been created!
Preparation of a new explainer is initiated

  -> data              : 5000 rows 37 cols
  -> target variable   : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray.
  -> target variable   : 5000 values
  -> model_class       : lightgbm.sklearn.LGBMRegressor (default)
  -> label             : gbm
  -> predict function  : <function predict_function at 0x2a5eb28e0> will be used
  -> predict function  : Accepts pandas.DataFrame and numpy.ndarray.
  -> predicted values  : min = 2.01e+05, mean = 7.43e+06, max = 1.04e+08
  -> model type        : regression will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         : min = -6e+06, mean = 4.17e+04, max = 8.85e+06
  -> model_info        : package lightgbm

A new explainer has been created!

model_performance allows for easy model comparison.

In [8]:
pd.concat((exp_svm.model_performance().result, exp_gbm.model_performance().result))
Out[8]:
              mse          rmse        r2            mae            mad
svm  2.907931e+12  1.705266e+06  0.963016  950656.413963  531411.484922
gbm  4.142691e+11  6.436374e+05  0.994731  357519.691396  203774.166420
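
For reference, a numpy sketch of how the five columns can be recomputed from observed and predicted values. It assumes that mad means the median absolute residual, which the output above does not state, and it uses toy numbers.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """Recompute mse, rmse, r2, mae and mad (assumed: median absolute residual)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    res = y_true - y_pred
    mse = float(np.mean(res ** 2))
    ss_res = float(np.sum(res ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    return {
        "mse": mse,
        "rmse": mse ** 0.5,
        "r2": 1.0 - ss_res / ss_tot,
        "mae": float(np.mean(np.abs(res))),
        "mad": float(np.median(np.abs(res))),
    }

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.5, 2.0, 2.0, 4.5])
print(regression_metrics(y_true, y_pred))
```

In the table above, rmse is the square root of mse for both models, which matches this definition.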

The functionality above is accessible from the Explainer object through its methods.

Model-level and predict-level methods each return a new object that holds a result attribute (a pandas.DataFrame) and a plot method.

Features

shap wrapper

predict_parts and model_parts have a new type='shap_wrapper', which uses the shap package to produce SHAP value explanations.

In [9]:
pp = exp_gbm.predict_parts(X.iloc[[1]], type='shap_wrapper', shap_explainer_type="TreeExplainer")
type(pp)
Out[9]:
dalex.wrappers._shap.object.ShapWrapper
In [10]:
pp.plot()
In [11]:
pp.result  # shap_values
Out[11]:
array([[-0.77081136,  0.01370978, -0.00173968,  0.07231857,  0.33443162,
         0.12392682,  0.21949412,  0.02376815,  0.13200212, -0.00302423,
         0.01465973,  0.02671936,  0.50762831,  0.04577954,  0.15363355,
         0.00569062,  0.9946803 ,  0.0136185 ,  0.06577841,  0.01979262,
         0.04086393,  0.01517965,  0.03015692, -0.00488528, -0.00664732,
         0.18990343,  0.06282931, -0.00595453,  0.02480912,  0.00717499,
        -0.02651436, -0.00365084, -0.00319156, -0.00773565,  0.00282692,
        -0.00140915, -0.00816135]])
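
These values are tiny compared with player values in euros, which suggests they are on the log scale the model was trained on. That is an inference from the numbers, not something the output states. Under that assumption, and using the additivity of SHAP values, their sum is the gap between this player's log-prediction and the explainer's baseline, and exp of the sum is a multiplicative factor on the original scale. The array below is copied from Out[11].

```python
import numpy as np

# SHAP values for the explained player, copied from Out[11] above.
shap_row = np.array([
    -0.77081136, 0.01370978, -0.00173968, 0.07231857, 0.33443162,
    0.12392682, 0.21949412, 0.02376815, 0.13200212, -0.00302423,
    0.01465973, 0.02671936, 0.50762831, 0.04577954, 0.15363355,
    0.00569062, 0.9946803, 0.0136185, 0.06577841, 0.01979262,
    0.04086393, 0.01517965, 0.03015692, -0.00488528, -0.00664732,
    0.18990343, 0.06282931, -0.00595453, 0.02480912, 0.00717499,
    -0.02651436, -0.00365084, -0.00319156, -0.00773565, 0.00282692,
    -0.00140915, -0.00816135,
])

total = float(shap_row.sum())    # log-scale shift from the baseline (assumed)
factor = float(np.exp(total))    # the same shift as a multiplier in euros (assumed)
print(total, factor)
```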
In [12]:
mp = exp_gbm.model_parts(type='shap_wrapper', shap_explainer_type="TreeExplainer")
type(mp)
Out[12]:
dalex.wrappers._shap.object.ShapWrapper
In [13]:
mp.plot()
In [14]:
mp.plot(plot_type='bar')
In [15]:
mp.result  # shap_values
Out[15]:
array([[ 2.71880735e-01,  2.11009130e-04, -1.79258487e-03, ...,
        -4.94233078e-03, -2.04894884e-02, -2.49443416e-02],
       [ 1.74457071e-01, -3.89064082e-03,  1.39621906e-03, ...,
        -2.65455706e-03, -1.33083409e-02, -9.89234325e-03],
       [ 3.85514783e-01, -5.17611933e-03,  1.85647040e-03, ...,
        -1.72370497e-03, -1.10533060e-02, -1.39986819e-02],
       ...,
       [-6.03835606e-01, -5.73905545e-03,  3.21374227e-03, ...,
        -2.44509308e-03, -1.08883673e-02, -1.58151287e-02],
       [ 1.83601058e-01,  1.27521452e-03, -1.03295708e-02, ...,
        -3.68566607e-03, -8.33553004e-03, -7.78299112e-03],
       [ 2.86290183e-01, -1.51155855e-03,  8.58770597e-04, ...,
        -3.63211472e-03, -1.32611700e-02, -1.03296419e-02]])
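
The bar variant of the shap summary plot ranks features by their mean absolute SHAP value. A numpy sketch of that ranking, on a small made-up matrix with illustrative feature names (not the values from Out[15]):

```python
import numpy as np

# Made-up SHAP matrix: rows are observations, columns are features.
shap_demo = np.array([[0.27, 0.00, -0.02],
                      [0.17, 0.00, -0.01],
                      [-0.60, -0.01, -0.02]])
feature_names = ["age", "height_cm", "weight_kg"]  # illustrative labels only

# Global importance: average magnitude of each feature's contribution.
importance = np.abs(shap_demo).mean(axis=0)
order = np.argsort(importance)[::-1]  # most important first
ranking = [(feature_names[i], float(importance[i])) for i in order]
print(ranking)
```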

model_diagnostics

The new model_diagnostics method supports residual diagnostics.

In [16]:
md_svm = exp_svm.model_diagnostics()
md_gbm = exp_gbm.model_diagnostics()
md_svm.plot(md_gbm, variable='age', yvariable='residuals', marker_size=5)

It can also be used for some exploratory data analysis.

In [17]:
md_svm.plot(variable='movement_reactions', yvariable='y', marker_size=5)
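
The explainer output reports the default residual function as the difference between y and yhat. A short numpy check of that arithmetic, using the y and y_hat values of the first three svm rows from the md_svm table shown later in this notebook:

```python
import numpy as np

# y and y_hat of the first three svm rows in the md_svm output.
y_obs = np.array([95500000.0, 58500000.0, 105500000.0])
y_hat = np.array([8.067993e+07, 6.582789e+07, 8.686368e+07])

residuals = y_obs - y_hat            # positive: the model under-predicts
abs_residuals = np.abs(residuals)
print(residuals)
print(abs_residuals)
```

Up to the rounding of the printed y_hat values, the results agree with the residuals and abs_residuals columns of that table.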

predict_surrogate

The new predict_surrogate method uses the lime package to produce LIME explanations.

In [18]:
lime = exp_gbm.predict_surrogate(X.iloc[[1]])
type(lime)
Out[18]:
lime.explanation.Explanation
In [19]:
lime.plot()
In [20]:
lime.result
Out[20]:
   variable                                   effect
0  movement_reactions > 75.00           3.131761e+06
1  age > 30.00                         -2.723991e+06
2  skill_ball_control > 76.00           1.738034e+06
3  attacking_finishing > 70.00          1.301945e+06
4  attacking_short_passing > 75.00      1.179459e+06
5  movement_sprint_speed > 77.00        7.282494e+05
6  attacking_heading_accuracy > 72.00   6.648493e+05
7  mentality_vision > 72.00             6.383244e+05
8  mentality_positioning > 73.00        5.614333e+05
9  skill_dribbling > 76.00              5.499609e+05
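
The idea behind a local surrogate can be sketched in plain numpy: perturb the instance, weight the perturbed points by proximity, and fit a weighted linear model to the black-box predictions. This is a simplified illustration, not the lime package's implementation (lime discretizes features, as the "> 75.00" rules above show). The function name, the Gaussian kernel, and the toy black box are assumptions.

```python
import numpy as np

def local_linear_surrogate(predict, x, n_samples=500, scale=1.0,
                           kernel_width=1.0, seed=0):
    """Fit a proximity-weighted linear model around x; return (intercept, slopes)."""
    rng = np.random.default_rng(seed)
    Z = x + rng.normal(0.0, scale, size=(n_samples, x.size))  # perturbed copies of x
    y = predict(Z)
    dist = np.linalg.norm(Z - x, axis=1)
    w = np.exp(-(dist ** 2) / kernel_width ** 2)              # closer points weigh more
    A = np.column_stack([np.ones(n_samples), Z - x])          # centered design matrix
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[0], coef[1:]

def toy_black_box(Z):
    # Toy nonlinear model, not the gbm from this notebook.
    return Z[:, 0] ** 2 + 3.0 * Z[:, 1]

intercept, slopes = local_linear_surrogate(toy_black_box, np.array([2.0, 1.0]))
print(intercept, slopes)
```

The slopes play the role of the effect column: they describe how the prediction changes near the explained instance, not globally.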

model_surrogate

The new model_surrogate method creates global surrogate models. For type='tree', a DecisionTree is fitted; the returned object has an additional performance attribute and a plot method that uses the sklearn.tree.plot_tree function.

In [21]:
surrogate_model_small = exp_gbm.model_surrogate(type='tree', max_depth=3, max_vars=3)
surrogate_model_small.performance
Out[21]:
                                mse          rmse       r2           mae           mad
DecisionTreeRegressor  2.260452e+13  4.754421e+06  0.70541  2.972569e+06  1.802269e+06
In [22]:
surrogate_model_big = exp_gbm.model_surrogate(type='tree', max_depth=4, max_vars=4)
surrogate_model_big.performance
Out[22]:
                                mse          rmse        r2           mae           mad
DecisionTreeRegressor  1.758802e+13  4.193808e+06  0.770787  2.663264e+06  1.708664e+06
In [23]:
surrogate_model_small.plot(figsize=(20, 8), fontsize=10, filled=True)
In [24]:
surrogate_model_big.plot(figsize=(20, 10), fontsize=9)
In [25]:
type(surrogate_model_big)
Out[25]:
sklearn.tree._classes.DecisionTreeRegressor
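
The general technique behind a global surrogate can be sketched with scikit-learn alone: fit a shallow tree to the black box's predictions rather than to the true target, then measure how closely it imitates them. The data and models below are synthetic stand-ins, and scoring against the black-box predictions is the usual fidelity measure, not a confirmed description of how dalex fills the performance attribute.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import r2_score
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X_demo = rng.uniform(-2.0, 2.0, size=(1000, 4))
y_demo = X_demo[:, 0] ** 2 + 2.0 * X_demo[:, 1] + rng.normal(0.0, 0.1, size=1000)

# Stand-in black box (the notebook uses an LGBMRegressor instead).
black_box = GradientBoostingRegressor(random_state=0).fit(X_demo, y_demo)
y_hat = black_box.predict(X_demo)

# The surrogate imitates the black box, so its target is y_hat, not y_demo.
surrogate = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X_demo, y_hat)
fidelity = r2_score(y_hat, surrogate.predict(X_demo))
print(fidelity)
```

A deeper tree usually tracks the black box more closely at the cost of readability, which is the trade-off between surrogate_model_small and surrogate_model_big above.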

plot profiles in PDP and ALE

In [26]:
pdp = exp_gbm.model_profile(variables=['age', 'movement_reactions', 'skill_ball_control', 'attacking_short_passing'],
                            N=100)
Calculating ceteris paribus: 100%|██████████| 4/4 [00:00<00:00, 55.47it/s]
In [27]:
pdp.plot(geom='profiles')
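
The progress bar above mentions ceteris paribus profiles: a partial dependence curve is the average of such per-observation profiles. A numpy sketch of that aggregation, with a toy model and a hypothetical helper name:

```python
import numpy as np

def partial_dependence(predict, X, feature, grid):
    """Average the ceteris-paribus profiles of all rows over a grid of values."""
    profiles = np.empty((X.shape[0], len(grid)))
    for j, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value          # change one feature, hold the rest fixed
        profiles[:, j] = predict(X_mod)
    return profiles, profiles.mean(axis=0)

def toy_model(X):
    # Toy additive model, not the gbm from this notebook.
    return 2.0 * X[:, 0] + X[:, 1] ** 2

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 2))
grid = np.linspace(-1.0, 1.0, 5)
profiles, pdp_curve = partial_dependence(toy_model, X_demo, feature=0, grid=grid)
print(pdp_curve)
```

Here profiles holds one curve per observation and pdp_curve is their mean.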

repr explanations

dalex explanations are now represented by their result attribute.

In [28]:
pdp
Out[28]:
     _vname_                  _label_    _x_        _yhat_  _ids_
  0  age                      gbm      16.00  1.106202e+07      0
  1  age                      gbm      16.25  1.106202e+07      0
  2  age                      gbm      16.50  1.106202e+07      0
  3  age                      gbm      16.75  1.106202e+07      0
  4  age                      gbm      17.00  1.106202e+07      0
...  ...                      ...        ...           ...    ...
399  attacking_short_passing  gbm      88.76  9.137024e+06      0
400  attacking_short_passing  gbm      89.57  9.137024e+06      0
401  attacking_short_passing  gbm      90.38  9.137024e+06      0
402  attacking_short_passing  gbm      91.19  9.137024e+06      0
403  attacking_short_passing  gbm      92.00  9.137024e+06      0

404 rows × 5 columns

In [29]:
md_svm
Out[29]:
age height_cm weight_kg attacking_crossing attacking_finishing attacking_heading_accuracy attacking_short_passing attacking_volleys skill_dribbling skill_curve ... goalkeeping_handling goalkeeping_kicking goalkeeping_positioning goalkeeping_reflexes y y_hat residuals abs_residuals label ids
short_name
L. Messi 32 170 72 88 95 70 92 88 97 93 ... 11 15 14 8 95500000 8.067993e+07 1.482007e+07 1.482007e+07 svm 1
Cristiano Ronaldo 34 187 83 84 94 89 83 87 89 81 ... 11 15 14 11 58500000 6.582789e+07 -7.327888e+06 7.327888e+06 svm 2
Neymar Jr 27 175 68 87 87 62 87 87 96 88 ... 9 15 15 11 105500000 8.686368e+07 1.863632e+07 1.863632e+07 svm 3
J. Oblak 26 188 87 13 11 15 43 13 12 13 ... 92 78 90 89 77500000 6.345192e+07 1.404808e+07 1.404808e+07 svm 4
E. Hazard 28 175 74 81 84 61 89 83 95 83 ... 12 6 8 8 90000000 7.368687e+07 1.631313e+07 1.631313e+07 svm 5
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
E. Calaio 37 179 75 50 74 73 63 73 70 63 ... 9 5 9 8 625000 7.633953e+05 -1.383953e+05 1.383953e+05 svm 4996
W. Hoolahan 37 168 71 72 63 43 74 62 70 73 ... 7 16 11 16 600000 7.328320e+05 -1.328320e+05 1.328320e+05 svm 4997
B. Johnson 32 178 68 68 68 72 65 71 66 58 ... 6 7 15 11 1200000 1.465683e+06 -2.656828e+05 2.656828e+05 svm 4998
L. Clarke 34 188 89 50 71 74 63 67 62 44 ... 10 6 8 8 925000 1.129820e+06 -2.048201e+05 2.048201e+05 svm 4999
B. Bialkowski 31 193 86 11 13 12 33 12 14 11 ... 70 65 73 71 1100000 1.197971e+06 -9.797134e+04 9.797134e+04 svm 5000

5000 rows × 43 columns

Plots

This package uses plotly to render the plots.

Resources - https://dalex.drwhy.ai/python