Aspect module in dalex

In the real world, we come across data with dependencies. It is almost impossible to avoid dependence among predictors when building predictive models.

Unfortunately, many commonly used explainable artificial intelligence (XAI) methods ignore these dependencies and often assume that variables are independent (as permutation methods do). This produces unrealistic combinations of feature values and misleading explanations.

Explaining models trained on correlated data is one of the pitfalls described in General Pitfalls of Model-Agnostic Interpretation Methods for Machine Learning Models.

We propose a way for ML engineers to explain their models while taking the dependencies between variables into account. The first part of the module consists of functionalities that estimate the importance and contribution of variables by grouping them into so-called aspects. The method is inspired by the Triplot paper.

In [1]:
import dalex as dx
import numpy as np

import plotly
plotly.offline.init_notebook_mode()
In [2]:
dx.__version__
Out[2]:
'1.7.0'

Case study - German Credit Data

To showcase the abilities of the module, we will use the German Credit Data dataset to assign a risk to each credit seeker.

In [3]:
# read data and create model

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier

# credit data
data = dx.datasets.load_german()

# risk is the target
X = data.drop(columns='risk')
y = data.risk

categorical_features = ['sex', 'job', 'housing', 'saving_accounts', 'checking_account', 'purpose']
categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])

numerical_features = ['age', 'duration', 'credit_amount']
numerical_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())
])

preprocessor = ColumnTransformer(transformers=[
        ('cat', categorical_transformer, categorical_features),
        ('num', numerical_transformer, numerical_features)
])


classifier = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)

clf = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', classifier)
])

clf.fit(X, y)
Out[3]:
Pipeline(steps=[('preprocessor',
                 ColumnTransformer(transformers=[('cat',
                                                  Pipeline(steps=[('onehot',
                                                                   OneHotEncoder(handle_unknown='ignore'))]),
                                                  ['sex', 'job', 'housing',
                                                   'saving_accounts',
                                                   'checking_account',
                                                   'purpose']),
                                                 ('num',
                                                  Pipeline(steps=[('scaler',
                                                                   StandardScaler())]),
                                                  ['age', 'duration',
                                                   'credit_amount'])])),
                ('classifier',
                 RandomForestClassifier(max_depth=5, random_state=42))])

Now that we have a model, it is time to explain it: we create an Explainer object.

In [4]:
exp = dx.Explainer(clf, X, y)
Preparation of a new explainer is initiated

  -> data              : 1000 rows 9 cols
  -> target variable   : Parameter 'y' was a pandas.Series. Converted to a numpy.ndarray.
  -> target variable   : 1000 values
  -> model_class       : sklearn.ensemble._forest.RandomForestClassifier (default)
  -> label             : Not specified, model's class short name will be used. (default)
  -> predict function  : <function yhat_proba_default at 0x14ea34400> will be used (default)
  -> predict function  : Accepts only pandas.DataFrame, numpy.ndarray causes problems.
  -> predicted values  : min = 0.264, mean = 0.701, max = 0.919
  -> model type        : classification will be used (default)
  -> residual function : difference between y and yhat (default)
  -> residuals         : min = -0.893, mean = -0.00121, max = 0.578
  -> model_info        : package sklearn

A new explainer has been created!

Creating Aspect - finding dependencies

Now we create an Aspect object from the explainer. It gives access to the dalex functionalities for explanations based on groups of dependent variables (aspects).

The Aspect object itself contains information about the dependencies between variables and their hierarchical clustering into aspects.

It is possible to choose how the dependencies are calculated. By default, the so-called association method is used, which relies on the following statistical coefficients:

  • for two numerical variables: the association is the absolute value of Spearman's rank correlation coefficient;
  • for two categorical variables: the association is the value of Cramér’s $V$ with bias correction (based on Pearson’s chi-squared statistic);
  • for one numerical and one categorical variable: the association is the value of eta-squared $\eta^2$ (based on the H-statistic of the Kruskal-Wallis test).
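The three coefficients listed above can be sketched with SciPy and pandas. This is our own illustration of the formulas, not dalex's internal code: we assume that Cramér's V uses Bergsma's bias correction applied to the chi-squared statistic without Yates' continuity correction, and that eta-squared is obtained from the Kruskal-Wallis statistic as $\eta^2 = (H - k + 1)/(n - k)$ for $k$ groups and $n$ observations. The function names are hypothetical.

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency, kruskal, spearmanr


def assoc_num_num(x, y):
    """Numerical-numerical: absolute value of Spearman's rank correlation."""
    rho, _ = spearmanr(x, y)
    return abs(rho)


def assoc_cat_cat(x, y):
    """Categorical-categorical: bias-corrected Cramer's V (Bergsma's correction)."""
    table = pd.crosstab(pd.Series(x), pd.Series(y)).to_numpy()
    chi2 = chi2_contingency(table, correction=False)[0]  # no Yates correction
    n = table.sum()
    r, k = table.shape
    phi2_corr = max(0.0, chi2 / n - (k - 1) * (r - 1) / (n - 1))
    r_corr = r - (r - 1) ** 2 / (n - 1)
    k_corr = k - (k - 1) ** 2 / (n - 1)
    return np.sqrt(phi2_corr / min(k_corr - 1, r_corr - 1))


def assoc_num_cat(x, groups):
    """Numerical-categorical: eta-squared from the Kruskal-Wallis H statistic."""
    x, groups = np.asarray(x), np.asarray(groups)
    levels = np.unique(groups)
    h = kruskal(*[x[groups == g] for g in levels])[0]
    n, k = len(x), len(levels)
    return (h - k + 1) / (n - k)


# Small checks on synthetic data
x = np.arange(10)
print(assoc_num_num(x, -x ** 3))        # perfectly monotonic: 1.0 up to rounding
cats = np.repeat(['a', 'b', 'c'], 100)
print(assoc_cat_cat(cats, cats))        # identical categories: 1.0 up to rounding
print(assoc_num_cat(np.arange(1, 31), np.repeat(['a', 'b', 'c'], 10)))  # about 0.88
```

All three measures are bounded by 1 in these examples, which is what makes them comparable enough to be placed in one dependency matrix.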

The user can also choose the pps method, based on the Predictive Power Score (PPS), or provide their own method. It is worth noting that PPS is a more restrictive measure, i.e., it sets weaker (often noise-related) dependencies to 0.
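To see why PPS trims weak dependencies, here is a minimal sketch of the score for a numeric target, following our understanding of the definition used by the ppscore package: a decision tree predicts y from x under cross-validation, its mean absolute error is compared with that of always predicting the median of y, and negative values are clipped to 0. Categorical targets use a different metric and are omitted, and how dalex turns this asymmetric score into a dependency matrix is not shown. The function name pps_numeric is hypothetical.

```python
import numpy as np
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.tree import DecisionTreeRegressor


def pps_numeric(x, y, n_splits=4, random_state=0):
    """Sketch of the Predictive Power Score of x for a numeric target y.

    score = 1 - MAE(cross-validated tree) / MAE(always predicting the median),
    clipped at 0, so predictors no better than the naive baseline score 0.
    """
    X = np.asarray(x, dtype=float).reshape(-1, 1)
    y = np.asarray(y, dtype=float)
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    tree = DecisionTreeRegressor(random_state=random_state)
    y_pred = cross_val_predict(tree, X, y, cv=cv)
    mae_model = np.mean(np.abs(y - y_pred))
    mae_naive = np.mean(np.abs(y - np.median(y)))
    return max(0.0, 1.0 - mae_model / mae_naive)


x = np.linspace(-1, 1, 200)
print(pps_numeric(x, x ** 2))  # high: x determines x**2
print(pps_numeric(x ** 2, x))  # low: x**2 does not reveal the sign of x
```

Note that the score is asymmetric, unlike the association coefficients: x predicts x**2 well, but x**2 cannot recover the sign of x.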

We will check what the hierarchical clustering of variables looks like with both available methods.

In [5]:
asp = dx.Aspect(exp)
In [6]:
asp_pps = dx.Aspect(exp, depend_method = 'pps')
In [7]:
asp.plot_dendrogram(title='Hierarchical clustering dendrogram (with association)')
In [8]:
asp_pps.plot_dendrogram(title='Hierarchical clustering dendrogram (with PPS)')

Note that the resulting hierarchical clustering structure depends on the chosen linkage method. The default linkage method is 'complete', so when two clusters are compared during clustering, the pair of variables with the lowest dependence between them is what counts.

We can check the dependencies more broadly by analyzing the dependency matrix.

In [9]:
asp.depend_matrix
Out[9]:
                       job  credit_amount  duration       age       sex   housing  saving_accounts  checking_account   purpose
job               1.000000       0.298345  0.227266  0.041327  0.002733  0.020185         0.000000          0.000327  0.016128
credit_amount     0.298345       1.000000  0.624709  0.026298  0.012432  0.026204         0.011311          0.008628  0.062200
duration          0.227266       0.624709  1.000000  0.036316  0.004336  0.021894         0.004466          0.007272  0.032833
age               0.041327       0.026298  0.036316  1.000000  0.048645  0.107473         0.011411          0.009741  0.025006
sex               0.002733       0.012432  0.004336  0.048645  1.000000  0.228046         0.022107          0.000000  0.119456
housing           0.020185       0.026204  0.021894  0.107473  0.228046  1.000000         0.000000          0.082131  0.159854
saving_accounts   0.000000       0.011311  0.004466  0.011411  0.022107  0.000000         1.000000          0.164016  0.049834
checking_account  0.000327       0.008628  0.007272  0.009741  0.000000  0.082131         0.164016          1.000000  0.105944
purpose           0.016128       0.062200  0.032833  0.025006  0.119456  0.159854         0.049834          0.105944  1.000000

We can also list the aspects created for a chosen dependence cutoff level h (the minimal dependence between variables within one aspect) or for a chosen maximum number of aspects n. The returned object is a dictionary in a form that can be passed to other functions in the package.

In [10]:
asp.get_aspects(h=0.1)
Out[10]:
{'aspect_1': ['saving_accounts', 'checking_account'],
 'aspect_2': ['sex', 'housing', 'purpose'],
 'aspect_3': ['job', 'credit_amount', 'duration'],
 'age': ['age']}
In [11]:
asp_pps.get_aspects(n=5)
Out[11]:
{'aspect_1': ['age', 'housing', 'sex'],
 'aspect_2': ['checking_account', 'saving_accounts'],
 'aspect_3': ['credit_amount', 'duration'],
 'job': ['job'],
 'purpose': ['purpose']}
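The link between the dendrogram, the cutoff h and the maximum number of aspects n can be sketched with SciPy's hierarchical clustering. We assume (this is not stated by dalex here) that the distance between two variables is 1 - dependence. Complete linkage then merges two clusters at a height equal to 1 minus the smallest dependence between their members, so cutting the tree at height 1 - h keeps only aspects whose members all depend on each other at least at level h. The matrix is copied from the asp.depend_matrix output; aspects_at is a hypothetical helper.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

names = ['job', 'credit_amount', 'duration', 'age', 'sex', 'housing',
         'saving_accounts', 'checking_account', 'purpose']
# Association matrix copied from the asp.depend_matrix output
D = np.array([
    [1.000000, 0.298345, 0.227266, 0.041327, 0.002733, 0.020185, 0.000000, 0.000327, 0.016128],
    [0.298345, 1.000000, 0.624709, 0.026298, 0.012432, 0.026204, 0.011311, 0.008628, 0.062200],
    [0.227266, 0.624709, 1.000000, 0.036316, 0.004336, 0.021894, 0.004466, 0.007272, 0.032833],
    [0.041327, 0.026298, 0.036316, 1.000000, 0.048645, 0.107473, 0.011411, 0.009741, 0.025006],
    [0.002733, 0.012432, 0.004336, 0.048645, 1.000000, 0.228046, 0.022107, 0.000000, 0.119456],
    [0.020185, 0.026204, 0.021894, 0.107473, 0.228046, 1.000000, 0.000000, 0.082131, 0.159854],
    [0.000000, 0.011311, 0.004466, 0.011411, 0.022107, 0.000000, 1.000000, 0.164016, 0.049834],
    [0.000327, 0.008628, 0.007272, 0.009741, 0.000000, 0.082131, 0.164016, 1.000000, 0.105944],
    [0.016128, 0.062200, 0.032833, 0.025006, 0.119456, 0.159854, 0.049834, 0.105944, 1.000000],
])

# Assumed distance: 1 - dependence. With complete linkage, two clusters merge
# at 1 - (smallest dependence between any of their members).
Z = linkage(squareform(1 - D, checks=False), method='complete')


def aspects_at(Z, names, h=None, n=None):
    """Cut the dendrogram at dependence level h, or into at most n aspects."""
    if h is not None:
        labels = fcluster(Z, t=1 - h, criterion='distance')
    else:
        labels = fcluster(Z, t=n, criterion='maxclust')
    groups = {}
    for name, label in zip(names, labels):
        groups.setdefault(label, []).append(name)
    return list(groups.values())


print(aspects_at(Z, names, h=0.1))
print(np.round(1 - Z[:, 2], 6))  # dependence level at each successive merge
```

With these values, the cut at h=0.1 should give the same four groups as asp.get_aspects(h=0.1), and 1 - Z[:, 2] should match the min_depend column of the model triplot result.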

Creating explanations - triplots

Having already calculated the hierarchical clustering of variables into aspects, we can move on to creating explanations.

The newly proposed type of explanation is the triplot. It creates explanations based on the identified dependency structure between variables and can be used both at the global (model) level and at the local (prediction) level.

The triplot analysis enables a deeper understanding of how dependencies between features influence the model prediction, helps to find an appropriate way of grouping features, and provides a basis for further model exploration.

Triplot gives a more holistic explanation of the importance of features by combining three panels:

  • (Local) variable importance – the importance of every single variable;
  • Hierarchical clustering – dependency structure between variables visualized by hierarchical clustering dendrogram;
  • Hierarchical aspect importance – the importance of groups of dependent variables determined by hierarchical clustering.

Model Triplot

So let's check what the triplot looks like in practice, starting with the analysis of the entire model.

With this method, importance is calculated using the permutation-based variable importance method provided by dalex, so the loss_function and type parameters can be chosen. The remaining parameters (N, B, and processes) allow a trade-off between the speed of calculations (for large datasets) and the stability of the results.
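The core idea of permutation importance for aspects, permuting all variables of an aspect with one shared row permutation so that dependencies inside the aspect are preserved, can be sketched with NumPy and scikit-learn. Assumptions: the loss is 1 - AUC (as we understand it, the dalex default for classification), all rows are used (no N subsampling), and the result is averaged over B permutations. grouped_permutation_importance is a hypothetical name and this is not dalex's implementation.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score


def loss_1_auc(model, X, y):
    return 1 - roc_auc_score(y, model.predict_proba(X)[:, 1])


def grouped_permutation_importance(model, X, y, groups, B=10, random_state=0):
    """Return the full-model loss and, per group, (dropout loss, change vs full model)."""
    rng = np.random.default_rng(random_state)
    full_loss = loss_1_auc(model, X, y)
    result = {}
    for name, cols in groups.items():
        losses = []
        for _ in range(B):
            perm = rng.permutation(len(X))
            X_perm = X.copy()
            # one permutation for all columns keeps the within-group dependence
            X_perm[:, cols] = X[perm][:, cols]
            losses.append(loss_1_auc(model, X_perm, y))
        dropout_loss = float(np.mean(losses))
        result[name] = (dropout_loss, dropout_loss - full_loss)
    return full_loss, result


# Synthetic example: y depends on columns 0 and 1 only
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 3))
y = (X[:, 0] + 0.5 * X[:, 1] + 0.5 * rng.normal(size=500) > 0).astype(int)
model = LogisticRegression().fit(X, y)
full_loss, imp = grouped_permutation_importance(model, X, y, {'signal': [0, 1], 'noise': [2]})
print(full_loss, imp)
```

The second element of each tuple plays the role of dropout_loss_change: the loss after permuting the group minus the loss of the unpermuted model.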

In [12]:
mt = asp.model_triplot(random_state=42)

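The main result of the explanation is the importance (expressed as dropout loss) of the aspects successively created during clustering. The dropout_loss_change column is the increase in loss caused by permuting the aspect; in the output below, dropout_loss - dropout_loss_change is the same in every row (about 0.1424), i.e. the loss of the model with no variables permuted.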

In addition, the resulting pd.DataFrame contains the pair of variables in each aspect with the smallest dependence, together with that dependence value (vars_min_depend and min_depend), which helps in analyzing the resulting groups.
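The min_depend and vars_min_depend columns can be recomputed from the dependency matrix: for each aspect, take the pair of member variables with the smallest dependence. Below is a minimal pandas sketch assuming this pairwise-minimum definition, which is consistent with the outputs (e.g. [job, duration] at 0.227266 for the aspect [job, credit_amount, duration]). min_dependence is a hypothetical helper; for a one-variable aspect it returns the variable paired with itself, mirroring the [age, age] entry in the aspect importance output.

```python
from itertools import combinations

import pandas as pd

# Part of asp.depend_matrix (values copied from its output)
vars_ = ['job', 'credit_amount', 'duration']
dep = pd.DataFrame(
    [[1.000000, 0.298345, 0.227266],
     [0.298345, 1.000000, 0.624709],
     [0.227266, 0.624709, 1.000000]],
    index=vars_, columns=vars_,
)


def min_dependence(dep, variables):
    """Smallest pairwise dependence inside an aspect and the pair attaining it."""
    if len(variables) == 1:
        return dep.loc[variables[0], variables[0]], [variables[0], variables[0]]
    pair = min(combinations(variables, 2), key=lambda p: dep.loc[p[0], p[1]])
    return dep.loc[pair[0], pair[1]], list(pair)


print(min_dependence(dep, ['job', 'credit_amount', 'duration']))  # pair: job, duration
print(min_dependence(dep, ['job']))
```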

In [13]:
mt
Out[13]:
variable_names dropout_loss dropout_loss_change min_depend vars_min_depend label
0 [credit_amount, duration] 0.223219 0.080776 0.624709 [credit_amount, duration] RandomForestClassifier
1 [sex, housing] 0.170177 0.027735 0.228046 [sex, housing] RandomForestClassifier
2 [job, credit_amount, duration] 0.239478 0.097035 0.227266 [job, duration] RandomForestClassifier
3 [saving_accounts, checking_account] 0.371607 0.229165 0.164016 [saving_accounts, checking_account] RandomForestClassifier
4 [sex, housing, purpose] 0.191592 0.049149 0.119456 [sex, purpose] RandomForestClassifier
5 [job, credit_amount, duration, age] 0.266759 0.124316 0.026298 [credit_amount, age] RandomForestClassifier
6 [job, credit_amount, duration, age, sex, housi... 0.305335 0.162892 0.002733 [job, sex] RandomForestClassifier
7 [job, credit_amount, duration, age, sex, housi... 0.498256 0.355814 0.000000 [job, saving_accounts] RandomForestClassifier

However, the biggest advantage of this explanation is its visual form made of three panels (hence the name).

In [14]:
mt.plot()

Looking at the triplot, we can simultaneously check which variables are most dependent on each other, and what their importance is, considering each of them both individually and in a group.

To facilitate such exploratory analysis, an interactive widget version of the triplot is also available. In the widget, clicking a dendrogram line highlights only the plot elements corresponding to the selected aspect.

NOTE: ipywidgets are not rendered in the HTML version of this notebook; run it in Jupyter to see them.

In [15]:
mt.plot(widget=True)

By analyzing the triplot, we can obtain a number of insights into how the model treats dependent variables. This can, for example, support variable selection.

Predict Triplot

The predict-level triplot, created for a selected observation, has a similar structure. We will look at this explanation for the observation about which the model was most uncertain, i.e. the one whose predicted probability is closest to 0.5.

In [16]:
# selected observation
ind = np.argmin(abs(asp.explainer.y_hat - 0.5))
X.iloc[[ind]]
Out[16]:
sex job housing saving_accounts checking_account credit_amount duration purpose age
951 male 2 own little little 2145 36 business 24
In [17]:
# y_true for selected observation
asp.explainer.y[ind]
Out[17]:
0
In [18]:
# y_hat predicted by model for selected observation
asp.explainer.y_hat[ind]
Out[18]:
0.4988224274107088

There are two possible ways of calculating importance on the local level:

  • the default one, used in the R package, which corresponds to the approach to explanations used by LIME;
  • the SHAP-based one (calculated with the SHAP implementation in dalex).

The method is selected with the type parameter.

The remaining parameters (N, B, and processes) allow for a compromise between the speed of calculations (for large datasets) and the stability of the results. There are also parameters specific to the default method (sample_method, f) that determine how aspects are perturbed in the data.
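The idea behind the default, LIME-like method can be illustrated with a simplified surrogate: sample background rows, randomly choose which aspects take the analyzed observation's values, and regress the resulting change in prediction on the binary aspect indicators. This is our own simplified sketch, not dalex's algorithm: each aspect is switched independently with probability 1/2, which stands in for sample_method and f, whose exact behaviour we do not reproduce. aspect_surrogate_importance is a hypothetical name.

```python
import numpy as np


def aspect_surrogate_importance(predict, data, x, aspects, N=4000, random_state=0):
    """LIME-style aspect importance (simplified sketch).

    predict: function mapping a 2-D array to predictions
    data:    background observations (2-D float array)
    x:       analyzed observation (1-D array)
    aspects: dict mapping aspect name -> list of column indices
    """
    rng = np.random.default_rng(random_state)
    names = list(aspects)
    rows = data[rng.integers(len(data), size=N)]       # sampled background rows
    masks = rng.integers(0, 2, size=(N, len(names)))   # 1 = aspect set to x's values
    changed = rows.copy()
    for j, name in enumerate(names):
        cols = aspects[name]
        changed[np.ix_(masks[:, j] == 1, cols)] = x[cols]
    diff = predict(changed) - predict(rows)            # effect of the switched aspects
    design = np.column_stack([np.ones(N), masks])      # intercept + binary indicators
    coef, *_ = np.linalg.lstsq(design, diff, rcond=None)
    return dict(zip(names, coef[1:]))


# Usage with a linear "model" whose aspect effects are known
rng = np.random.default_rng(0)
data = rng.normal(size=(5000, 3))
w = np.array([1.0, 2.0, 0.0])
x = np.array([1.0, 1.0, 1.0])
print(aspect_surrogate_importance(lambda A: A @ w, data, x, {'a': [0, 1], 'b': [2]}))
```

For this linear model the coefficient of aspect a should be close to 1*1 + 2*1 = 3 (the background means are near 0), and that of b close to 0, since its weight is zero.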

In [19]:
pt_def = asp.predict_triplot(X.iloc[ind], random_state=42)

The resulting data frame is very similar, but it also gives access to the variable values.

Due to the way the importance of aspects is calculated, for the last group, which includes all the variables, the importance equals the difference between the model prediction for the analyzed observation and the mean prediction (here 0.4988 - 0.701 ≈ -0.202).

In [20]:
pt_def
Out[20]:
variable_names variable_values importance min_depend vars_min_depend label
0 [credit_amount, duration] [2145, 36] -0.037960 0.624709 [credit_amount, duration] RandomForestClassifier
1 [sex, housing] [male, own] 0.023036 0.228046 [sex, housing] RandomForestClassifier
2 [job, credit_amount, duration] [2, 2145, 36] -0.041137 0.227266 [job, duration] RandomForestClassifier
3 [saving_accounts, checking_account] [little, little] -0.149826 0.164016 [saving_accounts, checking_account] RandomForestClassifier
4 [sex, housing, purpose] [male, own, business] 0.016265 0.119456 [sex, purpose] RandomForestClassifier
5 [job, credit_amount, duration, age] [2, 2145, 36, 24] -0.078037 0.026298 [credit_amount, age] RandomForestClassifier
6 [job, credit_amount, duration, age, sex, housi... [2, 2145, 36, 24, male, own, business] -0.057361 0.002733 [job, sex] RandomForestClassifier
7 [job, credit_amount, duration, age, sex, housi... [2, 2145, 36, 24, male, own, little, little, b... -0.202384 0.0 [job, saving_accounts] RandomForestClassifier
In [21]:
pt_def.plot()

When the importance of aspects varies greatly, the middle part of the triplot may be difficult to read and analyze. Once again, the interactive widget can help.

NOTE: ipywidgets are not rendered in the HTML version of this notebook; run it in Jupyter to see them.

In [22]:
pt_def.plot(widget=True)

For the sake of completeness, let's generate a triplot using importance based on Shapley values.

In [23]:
pt_shap = asp.predict_triplot(X.iloc[ind], type='shap', random_state=42)
In [24]:
pt_shap.plot()

We can see that the prediction was influenced mainly by the 'little' values of saving_accounts and checking_account.
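Treating each aspect as a single player, Shapley values for aspects can be computed exactly when there are only a few aspects. The sketch below uses an interventional value function: v(S) is the mean prediction when the aspects in S take the observation's values and all other variables keep their background values. It enumerates all subsets exactly, a swapped-in alternative to averaging over B random orderings as we assume dalex does; aspect_shapley is a hypothetical name.

```python
from itertools import combinations
from math import factorial

import numpy as np


def aspect_shapley(predict, background, x, aspects):
    """Exact Shapley values with aspects as players (interventional value function)."""
    names = list(aspects)
    k = len(names)

    def value(subset):
        # aspects in `subset` take x's values; the rest keep background values
        Z = background.copy()
        for name in subset:
            cols = aspects[name]
            Z[:, cols] = x[cols]
        return predict(Z).mean()

    phi = {}
    for name in names:
        others = [o for o in names if o != name]
        total = 0.0
        for size in range(k):
            weight = factorial(size) * factorial(k - size - 1) / factorial(k)
            for S in combinations(others, size):
                total += weight * (value(S + (name,)) - value(S))
        phi[name] = total
    return phi


rng = np.random.default_rng(0)
background = rng.normal(size=(200, 4))
w = np.array([1.0, -2.0, 0.5, 3.0])
x = np.array([0.3, 1.0, -1.0, 2.0])
aspects = {'a': [0, 1], 'b': [2], 'c': [3]}
phi = aspect_shapley(lambda A: A @ w, background, x, aspects)
print(phi)
print(sum(phi.values()), x @ w - (background @ w).mean())  # efficiency: equal
```

For a linear model, each aspect's value reduces to the sum of w_j (x_j - mean_j) over its variables, and the values always add up to the prediction for the observation minus the mean background prediction, the same quantity reported for the all-variable group in the predict triplot.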

Creating explanations - aspect importance

The triplot explanations show the hierarchical importance of the aspects, which gives insight into the behavior and decisions of the analyzed model. However, these methods do not give the full picture at a given cutoff level.

For this purpose, we present the Predict Aspect Importance and Model Aspect Importance methods. They do not use the entire hierarchical structure of dependencies, but create explanations for specific aspects (obtained at a selected cutoff level or defined on the basis of domain knowledge, e.g. groups of variables with a similar meaning).

Model Aspect Importance

Model Aspect Importance is also calculated using the permutation-based variable importance method provided by dalex, so it is possible to choose the previously mentioned parameters.

By calling this method from the Aspect object, we can specify the groups for which we want explanations with h, the cutoff level (the minimum dependence between the variables grouped in one aspect). If the model triplot has already been calculated with the same parameters, the method reuses those results.

In [25]:
mai = asp.model_parts(h=0.1, label='for aspects created on threshold h=0.1')
In [26]:
mai.plot()

As with other objects available in dalex, the generated charts can be customized with the parameters of the plot() method.

In [27]:
mai.plot(show_variable_names=False, bar_width=15, digits=5)

An important functionality is the ability to generate results not only at a given cut-off level, but also for aspects created on other bases, primarily domain knowledge or for groups of variables of similar meaning.

In [28]:
aspects = {'bio': ['sex', 'age'],
           'personal': ['job', 'housing'],
           'credit': ['credit_amount', 'purpose', 'duration'],
           'accounts': ['saving_accounts', 'checking_account']}

mai_asp = asp.model_parts(variable_groups=aspects, label='for aspects created by user')

You can compare the obtained results by placing them on one plot. This is a good form of comparison when creating multiple models.

In [29]:
mai.plot(mai_asp)

Predict Aspect Importance

Let's take a look at the predict-level explanation. The available arguments are the same as for Predict Triplot, including the two ways of calculating importance on the local level.

This method can also be called from the Aspect object, with the groups specified by the cutoff level h.

In [30]:
pai = asp.predict_parts(X.iloc[0], h=0.1, type='shap', random_state=42, label='client No. 0')

Note that for the non-triplot methods, the resulting data frames also contain information about the dependencies within the aspects.

In [31]:
pai
Out[31]:
aspect_name variable_names variable_values importance min_depend vars_min_depend label
0 aspect_1 [saving_accounts, checking_account] [not_known, little] -0.061472 0.164016 [saving_accounts, checking_account] client No. 0
1 aspect_3 [job, credit_amount, duration] [2, 1169, 6] 0.060977 0.227266 [job, duration] client No. 0
2 aspect_2 [sex, housing, purpose] [male, own, radio/TV] 0.036611 0.119456 [sex, purpose] client No. 0
3 age [age] [67] 0.020604 1.000000 [age, age] client No. 0
In [32]:
pai.plot()

If the default method is chosen to calculate the importance (coefficients in the surrogate model), we can use lasso regression and specify the maximum number of non-zero importances (n_aspects parameter).
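If, as we assume, the default method fits a linear surrogate on aspect indicators, the n_aspects limit can be illustrated by walking down the lasso path with scikit-learn's lasso_path and keeping the smallest penalty that still leaves at most n non-zero coefficients. lasso_top_n is a hypothetical helper, and dalex may select the penalty differently.

```python
import numpy as np
from sklearn.linear_model import lasso_path


def lasso_top_n(masks, diff, n):
    """Surrogate coefficients with at most n non-zero entries (sketch)."""
    # lasso_path fits no intercept, so centre the indicators and the response
    M = masks - masks.mean(axis=0)
    d = diff - diff.mean()
    alphas, coefs, _ = lasso_path(M, d)  # alphas are returned in decreasing order
    allowed = [i for i in range(len(alphas)) if np.count_nonzero(coefs[:, i]) <= n]
    return coefs[:, allowed[-1]]  # smallest penalty that respects the limit


# Synthetic indicators: aspects 0 and 1 matter most, aspect 2 a little
rng = np.random.default_rng(0)
masks = rng.integers(0, 2, size=(1000, 5)).astype(float)
diff = masks @ np.array([3.0, -2.0, 0.5, 0.0, 0.0]) + 0.1 * rng.normal(size=1000)
print(np.round(lasso_top_n(masks, diff, n=2), 3))
```

With n=2, only the two strongest aspects keep non-zero (shrunken) coefficients; the weaker aspect is set to 0, which is the effect the n_aspects parameter is meant to produce.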

In [33]:
pai_lasso = asp.predict_parts(X.iloc[0], h=0.1, type='default', n_aspects=3, random_state=42, label='client No. 0 (lasso)')
In [34]:
pai_lasso.plot()

It is also possible to easily compare the results of the explanations for different observations by placing explanations on one plot.

In [35]:
pai_1 = asp.predict_parts(X.iloc[1], h=0.1, type='shap', random_state=42, label='client No. 1')
pai_1.plot(pai)

Summary

The Aspect module in dalex is designed to facilitate the explanation process for groups of variables called aspects. This is done by keeping track of the hierarchical structure of dependencies between the variables represented by the dendrogram. It is also an alternative to explanations that operate only on single variables, which are often misleading.

Plots

This package uses plotly to render the plots:

  • Install extensions to use plotly in JupyterLab: Getting Started, Troubleshooting
  • Use show=False parameter in plot method to return plotly Figure object
  • It is possible to edit the figures and save them
  • Use widget=True parameter in plot method of triplot objects to return ipywidgets.HBox with plotly FigureWidget object

Resources - https://dalex.drwhy.ai/python