Aspect module in dalex¶
In the real world, we come across data with dependencies. It is almost impossible to avoid dependence among predictors when building predictive models.
Unfortunately, many commonly used explainable artificial intelligence (XAI) methods ignore these dependencies, often assuming independence of variables (permutation methods), which leads to unrealistic settings and misleading explanations.
Explaining models built on correlated data is one of the pitfalls described in General Pitfalls of Model-Agnostic Interpretation Methods for Machine Learning Models.
We propose a way in which ML engineers can explain their models while taking into account the dependencies between the variables. The first part of the module provides functionalities for estimating the importance and contribution of variables by grouping them into so-called aspects. The method is inspired by the Triplot paper.
import dalex as dx
import numpy as np
import plotly
plotly.offline.init_notebook_mode()
dx.__version__
Case study - german credit data¶
To showcase the abilities of the module, we will be using the German Credit Data dataset to assign a risk level to each credit applicant.
# read data and create model
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier
# credit data
data = dx.datasets.load_german()
# risk is the target
X = data.drop(columns='risk')
y = data.risk
categorical_features = ['sex', 'job', 'housing', 'saving_accounts', 'checking_account', 'purpose']
categorical_transformer = Pipeline(steps=[
('onehot', OneHotEncoder(handle_unknown='ignore'))
])
numerical_features = ['age', 'duration', 'credit_amount']
numerical_transformer = Pipeline(steps=[
('scaler', StandardScaler())
])
preprocessor = ColumnTransformer(transformers=[
('cat', categorical_transformer, categorical_features),
('num', numerical_transformer, numerical_features)
])
classifier = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42)
clf = Pipeline(steps=[
('preprocessor', preprocessor),
('classifier', classifier)
])
clf.fit(X, y)
We already have the model, so it is time to explain it - we create an Explainer object.
exp = dx.Explainer(clf, X, y)
Creating Aspect - finding dependencies¶
Now we create an Aspect object on the basis of the explainer. It enables the use of the dalex functionalities related to explanations in groups of dependent variables (aspects). The Aspect object itself contains information about the dependencies between variables and their hierarchical clustering into aspects.
It is possible to choose the method of calculating the dependencies. By default, the so-called association method is used, which relies on statistical coefficients (a rough code sketch follows the list):
- for two numerical variables: the association is the absolute value of Spearman's rank correlation coefficient;
- for two categorical variables: the association is the value of Cramér's $V$ with bias correction (based on Pearson's chi-squared statistic);
- for one numerical and one categorical variable: the association is the value of eta-squared $\eta^2$ (based on the $H$-statistic from the Kruskal-Wallis test).
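For intuition, below is a minimal, illustrative sketch of these three coefficients. It is not dalex's internal implementation, and the helper name association_sketch is made up.
import pandas as pd
from scipy import stats

def association_sketch(x, y):
    num_x, num_y = pd.api.types.is_numeric_dtype(x), pd.api.types.is_numeric_dtype(y)
    if num_x and num_y:
        # absolute Spearman's rank correlation for two numerical variables
        return abs(stats.spearmanr(x, y).correlation)
    if not num_x and not num_y:
        # bias-corrected Cramér's V for two categorical variables
        ct = pd.crosstab(x, y)
        chi2 = stats.chi2_contingency(ct)[0]
        n, (r, k) = ct.values.sum(), ct.shape
        phi2 = max(0, chi2 / n - (r - 1) * (k - 1) / (n - 1))
        r_c = r - (r - 1) ** 2 / (n - 1)
        k_c = k - (k - 1) ** 2 / (n - 1)
        return (phi2 / min(r_c - 1, k_c - 1)) ** 0.5
    # eta-squared from the Kruskal-Wallis H-statistic for a mixed pair
    cat, num = (x, y) if num_y else (y, x)
    groups = [num[cat == level] for level in cat.unique()]
    h, n, k = stats.kruskal(*groups).statistic, len(num), len(groups)
    return (h - k + 1) / (n - k)

association_sketch(X['age'], X['credit_amount'])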
The user can also use the pps method (the Predictive Power Score measure) or provide their own method. It is worth noting that PPS is a more restrictive measure, i.e., it trims the less significant (often noise-related) dependencies to 0.
We will check what the variable hierarchical clustering looks like with both available methods.
asp = dx.Aspect(exp)
asp_pps = dx.Aspect(exp, depend_method='pps')
asp.plot_dendrogram(title='Hierarchical clustering dendrogram (with association)')
asp_pps.plot_dendrogram(title='Hierarchical clustering dendrogram (with PPS)')
Note that the resulting hierarchical clustering structure is based on the specified type of clustering. The default linkage method is 'complete', so during clustering the variables with the lowest dependence are compared.
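If a different linkage is preferred, it appears it can be passed when constructing the Aspect object; the clustering_method value below is illustrative, so check the documentation for the accepted names.
# illustrative: average linkage instead of the default 'complete'
asp_avg = dx.Aspect(exp, clustering_method='average')
asp_avg.plot_dendrogram(title='Hierarchical clustering dendrogram (average linkage)')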
We can check the dependencies more broadly by analyzing the dependency matrix.
asp.depend_matrix
We can also see the aspects created based on a chosen dependency cutoff level h (the minimal dependence between variables in one cluster) or based on a chosen maximum number of aspects n. The returned object is a dictionary in a form that allows interaction with other functions in the package.
asp.get_aspects(h=0.1)
asp_pps.get_aspects(n=5)
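The keys are aspect names and the values are lists of variable names, so the dictionary can be inspected directly or passed to other methods, e.g.:
# print each aspect together with its member variables
for name, variables in asp.get_aspects(h=0.1).items():
    print(name, variables)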
Creating explanations - triplots¶
Having already calculated the hierarchical clustering of variables into aspects, we can move on to creating explanations.
The newly proposed type of explanation is the triplot. It is a tool that creates explanations based on the identified dependency structure of the variables. It can be used both on the global (model) level and on the local (predict) level.
The triplot analysis enables a deeper understanding of how dependencies between the features influence the model prediction, helps to find an appropriate approach to grouping features, and also provides a background for further model exploration.
Triplot gives a more holistic explanation of the importance of features by combining three panels:
- (Local) variable importance – the importance of every single variable;
- Hierarchical clustering – dependency structure between variables visualized by hierarchical clustering dendrogram;
- Hierarchical aspect importance – the importance of groups of dependent variables determined by hierarchical clustering.
Model Triplot¶
So let's check what the triplot looks like in practice, starting with the analysis for the entire model.
With this method, the importance is calculated using the permutation-based variable importance method provided by dalex, so it is possible to choose the loss_function and type. The remaining parameters (N, B, and processes) allow for a compromise between the speed of calculations (for large datasets) and the stability of the results.
mt = asp.model_triplot(random_state=42)
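For larger datasets, a faster (though less stable) run might look like this; the parameter values below are purely illustrative:
# subsample N rows, use fewer permutation rounds (B), compute in parallel
mt_fast = asp.model_triplot(N=500, B=5, processes=4, random_state=42)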
The main result of the explanation is the importance (expressed in the dropout loss) of the aspects successively created during clustering.
In addition, the main result pd.DataFrame also contains information about the variables in a given aspect that have the smallest dependence (min_depend and vars_min_depend), which helps in the analysis of the resulting groups.
mt
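Assuming the result is stored in the result attribute (as in other dalex explanation objects), the dependency columns can be examined on their own:
# the least dependent pair of variables within each aspect
mt.result[['min_depend', 'vars_min_depend']]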
However, the biggest advantage of this explanation is its visual form (as the name triplot suggests).
mt.plot()
Looking at the triplot, we can simultaneously check which variables are most dependent on each other and what their importance is, both individually and in groups.
To facilitate such exploratory analysis of the explanation, there is also an option to create an interactive widget version of the triplot. In the widget, after clicking on a dendrogram line, only the plot elements corresponding to the selected aspect are highlighted.
NOTE: ipywidgets are not rendered in the HTML version; see the Jupyter notebook.
mt.plot(widget=True)
By analyzing the triplot, we can obtain a number of insights that help understand how the model treats dependent variables. This can, for example, improve our variable selection efforts.
Predict Triplot¶
The predict-level triplot, created for a selected observation, has a similar structure. We will check what this explanation looks like for the observation about whose prediction the model was most uncertain.
# selected observation
ind = np.argmin(abs(asp.explainer.y_hat - 0.5))
X.iloc[[ind]]
# y_true for selected observation
asp.explainer.y[ind]
# y_hat predicted by model for selected observation
asp.explainer.y_hat[ind]
There are two possible ways of calculating importance on the local level:
- the default one, used in the R package, which corresponds to the approach to explanations used by LIME;
- the SHAP-based one (the implementation in dalex is used for the calculations).
The method can be selected by specifying the type parameter. The remaining parameters (N, B, and processes) allow for a compromise between the speed of calculations (for large datasets) and the stability of the results. There are also parameters specific to the default method (sample_method, f) that determine how aspects are perturbed in the data.
pt_def = asp.predict_triplot(X.iloc[ind], random_state=42)
The main result of the explanation is very similar, but we also have access to the variable values.
Due to the method of calculating the importance of aspects, in the case of the last group, which includes all the variables, the importance equals the difference between the model prediction for the analyzed observation and the mean prediction.
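We can verify this statement directly; up to estimation noise, the value below should match the importance of the last (all-variable) group:
# prediction for the selected observation minus the mean model prediction
asp.explainer.y_hat[ind] - asp.explainer.y_hat.mean()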
pt_def
pt_def.plot()
When the importance of aspects varies greatly, it may be difficult to read and analyze the middle part of the triplot. The solution, once again, is to use the interactive widget.
NOTE: ipywidgets are not rendered in the HTML version; see the Jupyter notebook.
pt_def.plot(widget=True)
For the sake of completeness, let's generate a triplot using importance based on Shapley values.
pt_shap = asp.predict_triplot(X.iloc[ind], type='shap', random_state=42)
pt_shap.plot()
We can see that the prediction was influenced mainly by the 'little' values of saving_accounts and checking_account.
Creating explanations - aspect importance¶
The triplot explanations show the hierarchical importance of the aspects, which gives an insight into the behavior and decisions of the analyzed model. However, these methods do not give the full picture of the situation at a given cutoff level.
For this purpose, we present the Predict Aspect Importance and Model Aspect Importance methods. They do not use the entire hierarchical structure of dependencies, but create an explanation for specific aspects (selected at a cut-off level or created on the basis of domain knowledge, e.g. groups of variables with a similar meaning).
Model Aspect Importance¶
Model Aspect Importance is also calculated using the permutation-based variable importance method provided by dalex, so the previously mentioned parameters can be chosen. By calling this method from the Aspect object, we can specify the groups for which we want to obtain explanations with the use of the h cut-off level (the minimum value of the dependency between the variables grouped in one aspect). Moreover, if the model triplot was previously calculated with the same parameters, the method will reuse the already computed results.
mai = asp.model_parts(h=0.1, label='for aspects created on threshold h=0.1')
mai.plot()
Moreover, you can modify the generated charts in the plot() method, similarly to other objects available in dalex.
mai.plot(show_variable_names=False, bar_width=15, digits=5)
An important functionality is the ability to generate results not only at a given cut-off level, but also for aspects created on other grounds, primarily domain knowledge, e.g. for groups of variables with a similar meaning.
aspects = {'bio': ['sex', 'age'],
'personal': ['job', 'housing'],
'credit': ['credit_amount', 'purpose', 'duration'],
'accounts': ['saving_accounts', 'checking_account']}
mai_asp = asp.model_parts(variable_groups=aspects, label='for aspects created by user')
You can compare the obtained results by placing them on one plot. This form of comparison is also useful when working with multiple models.
mai.plot(mai_asp)
Predict Aspect Importance¶
Let's take a look at the predict-level explanation. As before, the specifiable arguments coincide with those of Predict Triplot, and there are again two ways of calculating the importance on the local level.
We can also call this method from the Aspect object and specify the groups for which we want to obtain explanations with the use of the h cut-off level.
pai = asp.predict_parts(X.iloc[0], h=0.1, type='shap', random_state=42, label='client No. 0')
Note that in the case of the non-triplot methods, the resulting data frames also contain information about the dependencies within the aspects.
pai
pai.plot()
If the default method is chosen to calculate the importance (coefficients of a surrogate model), we can use lasso regression and specify the maximum number of non-zero importances (the n_aspects parameter).
pai_lasso = asp.predict_parts(X.iloc[0], h=0.1, type='default', n_aspects=3, random_state=42, label='client No. 0 (lasso)')
pai_lasso.plot()
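As a quick sanity check (assuming the results are kept in a result data frame with an importance column, as elsewhere in dalex), the lasso constraint should leave at most three non-zero importances:
# count aspects with non-zero estimated importance
(pai_lasso.result['importance'] != 0).sum()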
It is also possible to easily compare the results of the explanations for different observations by placing explanations on one plot.
pai_1 = asp.predict_parts(X.iloc[1], h=0.1, type='shap', random_state=42, label='client No. 1')
pai_1.plot(pai)
Summary¶
The Aspect module in dalex is designed to facilitate the explanation process for groups of variables called aspects. This is done by keeping track of the hierarchical structure of dependencies between the variables, represented by a dendrogram. It is also an alternative to explanations that operate only on single variables, which are often misleading in the presence of dependencies.
Plots¶
This package uses plotly to render the plots:
- Install extensions to use plotly in JupyterLab: Getting Started, Troubleshooting
- Use the show=False parameter in the plot method to return a plotly Figure object; it is then possible to edit the figures and save them (see the example below)
- Use the widget=True parameter in the plot method of triplot objects to return an ipywidgets.HBox with a plotly FigureWidget object
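For example, a figure can be captured and saved to a file (static image export assumes the kaleido package is installed):
# return the plotly Figure instead of displaying it, then save it
fig = mt.plot(show=False)
fig.write_html('triplot.html')   # interactive HTML
fig.write_image('triplot.png')   # static image, requires kaleido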
Resources - https://dalex.drwhy.ai/python¶
- Introduction to the dalex package: Titanic: tutorial and examples
- Key features explained: FIFA20: explain default vs tuned model with dalex
- How to use dalex with: xgboost, tensorflow, h2o (feat. autokeras, catboost, lightgbm)
- More explanations: residuals, shap, lime
- Introduction to the Fairness module in dalex
- Introduction to the Aspect module in dalex
- Introduction to Arena: interactive dashboard for model exploration
- Code in the form of a jupyter notebook
- Changelog: NEWS
- Theoretical introduction to the plots: Explanatory Model Analysis: Explore, Explain, and Examine Predictive Models