tensorflow + dalex = :)

In [1]:
import warnings
warnings.filterwarnings('ignore')

import plotly
plotly.offline.init_notebook_mode()

read data

In [2]:
import pandas as pd
pd.__version__
Out[2]:
'1.5.3'
In [3]:
data = pd.read_csv("https://raw.githubusercontent.com/pbiecek/xai-happiness/main/happiness.csv", index_col=0)
data.head()
Out[3]:
score gdp_per_capita social_support healthy_life_expectancy freedom_to_make_life_choices generosity perceptions_of_corruption
Afghanistan 3.203 0.350 0.517 0.361 0.000 0.158 0.025
Albania 4.719 0.947 0.848 0.874 0.383 0.178 0.027
Algeria 5.211 1.002 1.160 0.785 0.086 0.073 0.114
Argentina 6.086 1.092 1.432 0.881 0.471 0.066 0.050
Armenia 4.559 0.850 1.055 0.815 0.283 0.095 0.064
In [4]:
X, y = data.drop('score', axis=1), data.score
n, p = X.shape

create a model

In [5]:
import tensorflow as tf
tf.__version__
Out[5]:
'2.16.0-rc0'
In [6]:
tf.random.set_seed(11)

normalizer = tf.keras.layers.Normalization(input_shape=[p,])
normalizer.adapt(X.to_numpy())

model = tf.keras.Sequential([
    normalizer,
    tf.keras.layers.Dense(p*2, activation='relu'),
    tf.keras.layers.Dense(p*3, activation='relu'),
    tf.keras.layers.Dense(p*2, activation='relu'),
    tf.keras.layers.Dense(p, activation='relu'),
    tf.keras.layers.Dense(1, activation='linear')
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.mae
)
In [7]:
model.fit(X, y, batch_size=n // 10, epochs=2000, verbose=False)
Out[7]:
<keras.src.callbacks.history.History at 0x2dd3edb20>

explain the model

Explainer initialization communicates useful information

In [8]:
import dalex as dx
dx.__version__
Out[8]:
'1.7.0'
In [9]:
explainer = dx.Explainer(model, X, y, label='happiness')
Preparation of a new explainer is initiated

  -> data              : 156 rows 6 cols
  -> target variable   : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray.
  -> target variable   : 156 values
  -> model_class       : keras.src.models.sequential.Sequential (default)
  -> label             : happiness
  -> predict function  : <function yhat_tf_regression at 0x2de68dbc0> will be used (default)
  -> predict function  : Accepts pandas.DataFrame and numpy.ndarray.
  -> predicted values  : min = 2.92, mean = 5.47, max = 7.83
  -> model type        : regression will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         : min = -1.06, mean = -0.061, max = 0.613
  -> model_info        : package keras

A new explainer has been created!

model level explanations

first, assess the model's fit to the training data

In [10]:
explainer.model_performance()
Out[10]:
mse rmse r2 mae mad
happiness 0.033685 0.183534 0.972638 0.100627 0.056535
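
The metrics reported above can be reproduced by hand from the predictions. A minimal sketch with numpy on toy values (mad is taken here as the median absolute deviation of the residuals, which may differ in detail from dalex's definition):

```python
import numpy as np

# Toy targets and predictions, standing in for y and explainer.predict(X).
y = np.array([3.0, 4.0, 5.0, 6.0])
y_hat = np.array([3.1, 3.8, 5.2, 6.0])

residuals = y - y_hat
mse = np.mean(residuals ** 2)                                  # mean squared error
rmse = np.sqrt(mse)                                            # its square root
mae = np.mean(np.abs(residuals))                               # mean absolute error
mad = np.median(np.abs(residuals - np.median(residuals)))      # median absolute deviation
r2 = 1 - np.sum(residuals ** 2) / np.sum((y - y.mean()) ** 2)  # coefficient of determination
```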

which features are the most important?

In [11]:
explainer.model_parts().plot()
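
model_parts computes permutation-based variable importance: shuffle one column at a time and measure how much the loss grows. The idea can be sketched with numpy on a stand-in model (not the Keras model above):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: feature 0 drives the target, feature 1 is pure noise.
X = rng.normal(size=(500, 2))
y = 3 * X[:, 0] + rng.normal(scale=0.1, size=500)

def predict(X):
    # Stand-in fitted model: recovers the true relationship.
    return 3 * X[:, 0]

def rmse(y_true, y_pred):
    return np.sqrt(np.mean((y_true - y_pred) ** 2))

baseline = rmse(y, predict(X))
importance = {}
for j in range(X.shape[1]):
    Xp = X.copy()
    Xp[:, j] = rng.permutation(Xp[:, j])             # destroy this column's information
    importance[j] = rmse(y, predict(Xp)) - baseline  # loss increase = importance
```

The informative feature's importance dwarfs the noise feature's, which is exactly the contrast the bar plot above visualizes.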

what are the relationships between the continuous variables and the predictions?

In [12]:
explainer.model_profile().plot(variables=['social_support', 'healthy_life_expectancy',
                                          'gdp_per_capita', 'freedom_to_make_life_choices'])
Calculating ceteris paribus: 100%|██████████| 6/6 [00:00<00:00,  6.52it/s]
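
model_profile aggregates such ceteris-paribus profiles into a partial-dependence curve: pin one variable to each value on a grid for every observation and average the predictions. A minimal sketch of that aggregation, again on a stand-in model:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.uniform(size=(200, 2))

def predict(X):
    # Stand-in model with a known, monotone effect of feature 0.
    return X[:, 0] ** 2 + 0.5 * X[:, 1]

# Partial dependence for feature 0.
grid = np.linspace(0, 1, 11)
profile = []
for v in grid:
    Xv = X.copy()
    Xv[:, 0] = v           # ceteris paribus: only feature 0 changes
    profile.append(predict(Xv).mean())
profile = np.array(profile)
```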

what about residuals?

In [13]:
explainer.model_diagnostics().plot(variable='social_support', yvariable="abs_residuals", marker_size=5, line_width=3)
In [14]:
explainer.model_diagnostics().result
Out[14]:
gdp_per_capita social_support healthy_life_expectancy freedom_to_make_life_choices generosity perceptions_of_corruption y y_hat residuals abs_residuals label ids
Afghanistan 0.350 0.517 0.361 0.000 0.158 0.025 3.203 3.257266 -0.054266 0.054266 happiness 1
Albania 0.947 0.848 0.874 0.383 0.178 0.027 4.719 4.713201 0.005799 0.005799 happiness 2
Algeria 1.002 1.160 0.785 0.086 0.073 0.114 5.211 5.262719 -0.051719 0.051719 happiness 3
Argentina 1.092 1.432 0.881 0.471 0.066 0.050 6.086 6.093252 -0.007252 0.007252 happiness 4
Armenia 0.850 1.055 0.815 0.283 0.095 0.064 4.559 4.606464 -0.047464 0.047464 happiness 5
... ... ... ... ... ... ... ... ... ... ... ... ...
Venezuela 0.960 1.427 0.805 0.154 0.064 0.047 4.707 4.686743 0.020257 0.020257 happiness 152
Vietnam 0.741 1.346 0.851 0.543 0.147 0.073 5.175 5.488496 -0.313496 0.313496 happiness 153
Yemen 0.287 1.163 0.463 0.143 0.108 0.077 3.380 4.252674 -0.872674 0.872674 happiness 154
Zambia 0.578 1.058 0.426 0.431 0.247 0.087 4.107 4.175465 -0.068465 0.068465 happiness 155
Zimbabwe 0.366 1.114 0.433 0.361 0.151 0.089 3.663 3.735062 -0.072062 0.072062 happiness 156

156 rows × 12 columns

predict level explanations

investigate a specific country

In [15]:
explainer.predict_parts(X.loc['Poland'], type='shap').plot()
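
Under the hood, break_down starts from the mean prediction and fixes variables one at a time to the instance's values; each step's shift is that variable's contribution (the shap type averages such steps over many orderings). A sketch with an additive toy model, where the contributions do not depend on the ordering:

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.normal(size=(300, 3))

def predict(X):
    # Additive stand-in model, so contributions are unambiguous.
    return 1.0 + 2 * X[:, 0] - 1 * X[:, 1] + 0.5 * X[:, 2]

instance = np.array([1.0, -0.5, 2.0])

baseline = predict(X).mean()       # mean prediction = the starting intercept
contributions = []
Xc = X.copy()
prev = baseline
for j in range(3):
    Xc[:, j] = instance[j]         # fix variable j to the instance's value
    cur = predict(Xc).mean()
    contributions.append(cur - prev)
    prev = cur

# Baseline plus all contributions reconstructs the instance's prediction.
total = baseline + sum(contributions)
```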

or several countries

In [16]:
pp_list = []
for country in ['Afghanistan', 'Belgium', 'China', 'Denmark', 'Ethiopia']:
    pp = explainer.predict_parts(X.loc[country], type='break_down')
    pp.result.label = country
    pp_list.append(pp)
pp_list[0].plot(pp_list[1:], min_max=[2.5, 8.5])

surrogate approximation

In [17]:
lime_explanation = explainer.predict_surrogate(X.loc['United States'], mode='regression')
In [18]:
lime_explanation.plot()
In [19]:
lime_explanation.result
Out[19]:
variable effect
0 social_support > 1.45 0.668891
1 generosity > 0.25 0.488504
2 gdp_per_capita > 1.23 0.466195
3 0.09 < perceptions_of_corruption <= 0.14 -0.367146
4 0.79 < healthy_life_expectancy <= 0.88 0.180251
5 0.42 < freedom_to_make_life_choices <= 0.51 -0.128612
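
predict_surrogate delegates to the lime package; the underlying idea is to fit a simple weighted linear model on perturbed copies of the instance, with weights decaying with distance, and read off the coefficients as local effects. A numpy-only sketch of that local fit, on a stand-in nonlinear model (not the network above):

```python
import numpy as np

rng = np.random.default_rng(3)

def predict(X):
    # Nonlinear stand-in model to be explained locally.
    return np.sin(X[:, 0]) + X[:, 1] ** 2

x0 = np.array([0.0, 1.0])          # the instance to explain

# Perturb around the instance; weight samples by proximity (Gaussian kernel).
Z = x0 + rng.normal(scale=0.5, size=(1000, 2))
w = np.exp(-np.sum((Z - x0) ** 2, axis=1))

# Weighted least squares: scale rows by sqrt(weight) and solve.
A = np.hstack([Z - x0, np.ones((len(Z), 1))])
sw = np.sqrt(w)
coef, *_ = np.linalg.lstsq(A * sw[:, None], predict(Z) * sw, rcond=None)
# coef[0], coef[1] approximate the local slopes (cos(0) = 1 and 2 * 1 = 2).
```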

interpretable surrogate model

In [20]:
surrogate_model = explainer.model_surrogate(max_vars=4, max_depth=3)
surrogate_model.performance
Out[20]:
mse rmse r2 mae mad
DecisionTreeRegressor 0.180179 0.424475 0.847539 0.338915 0.300332
In [21]:
surrogate_model.plot()

Plots

This package uses plotly to render the plots.

Resources: https://dalex.drwhy.ai/python