Module dalex.aspect

from .object import Aspect
from ._predict_aspect_importance import PredictAspectImportance
from ._model_aspect_importance import ModelAspectImportance
from ._predict_triplot import PredictTriplot
from ._model_triplot import ModelTriplot

__all__ = [
    "Aspect",
    "PredictAspectImportance",
    "ModelAspectImportance",
    "PredictTriplot",
    "ModelTriplot"
    ]

Sub-modules

dalex.aspect.checks
dalex.aspect.object
dalex.aspect.plot
dalex.aspect.utils

Classes

class Aspect (explainer, depend_method='assoc', clust_method='complete', corr_method='spearman', agg_method='max')

Create Aspect

Explanation methods that do not take into account dependencies between variables can produce misleading results. This class creates a representation of a model based on an Explainer object. In addition, it calculates the relationships between the variables that can be used to create explanations. Methods of this class produce explanation objects that contain the main result attribute and can be visualised using the plot method.

The explainer is the only required parameter.
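
For example, a minimal construction sketch (assuming a fitted scikit-learn model clf and training data X, y; the names are illustrative):

import dalex as dx

# wrap the fitted model and its training data in an Explainer
explainer = dx.Explainer(clf, X, y, label="model")

# build the Aspect representation; by default this computes the 'assoc'
# dependency matrix and a 'complete'-linkage clustering of the variables
aspect = dx.Aspect(explainer)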

Parameters

explainer : Explainer object
Model wrapper created using the Explainer class.
depend_method : {'assoc', 'pps'} or function, optional
The method of calculating the dependencies between variables (i.e. the dependency matrix). Default is 'assoc', which means the use of statistical association (correlation coefficient, Cramér's V based on Pearson's chi-squared statistic and eta-squared based on the Kruskal-Wallis H-statistic); 'pps' stands for Predictive Power Score. NOTE: When a function is passed, it is called with the explainer.data and it must return a symmetric dependency matrix (pd.DataFrame with variable names as columns and rows); see the sketch after this parameter list.
clust_method : {'complete', 'single', 'average', 'weighted', 'centroid', 'median', 'ward'}, optional
The linkage algorithm to use for variables hierarchical clustering (default is 'complete').
corr_method : {'spearman', 'pearson', 'kendall'}, optional
The method of calculating correlation between numerical variables (default is 'spearman'). NOTE: Ignored if depend_method is not 'assoc'.
agg_method : {'max', 'min', 'avg'}, optional
The method of aggregating the PPS values for pairs of variables (default is 'max'). NOTE: Ignored if depend_method is not 'pps'.
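
As an illustration of the function variant of depend_method, a minimal sketch (the helper name abs_corr_depend is hypothetical and assumes all columns of explainer.data are numeric):

def abs_corr_depend(data):
    # hypothetical dependency measure: absolute Pearson correlation,
    # returned as a symmetric pd.DataFrame with variable names as both
    # rows and columns
    return data.corr(method="pearson").abs()

aspect = dx.Aspect(explainer, depend_method=abs_corr_depend)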

Attributes

explainer : Explainer object
Model wrapper created using the Explainer class.
depend_method : {'assoc', 'pps'} or function
The method of calculating the dependencies between variables.
clust_method : {'complete', 'single', 'average', 'weighted', 'centroid', 'median', 'ward'}
The linkage algorithm to use for variables hierarchical clustering.
corr_method : {'spearman', 'pearson', 'kendall'}
The method of calculating correlation between numerical variables.
agg_method : {'max', 'min', 'avg'}
The method of aggregating the PPS values for pairs of variables.
depend_matrix : pd.DataFrame
The dependency matrix (with variable names as columns and rows).

linkage_matrix :
The hierarchical clustering of variables encoded as a scipy linkage matrix.

Notes

- assoc, eta-squared: http://tss.awf.poznan.pl/files/3_Trends_Vol21_2014__no1_20.pdf
- assoc, Cramér's V: http://stats.lse.ac.uk/bergsma/pdf/cramerV3.pdf
- PPS: https://github.com/8080labs/ppscore
- triplot: https://arxiv.org/abs/2104.03403
class Aspect:
    """Create Aspect

    Explanation methods that do not take into account dependencies between variables
    can produce misleading results. This class creates a representation of a model based
    on an Explainer object. In addition, it calculates the relationships between
    the variables that can be used to create explanations. Methods of this class produce
    explanation objects that contain the main result attribute and can be visualised
    using the plot method.

    The `explainer` is the only required parameter.

    Parameters
    ----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.
    depend_method : {'assoc', 'pps'} or function, optional
        The method of calculating the dependencies between variables (i.e. the dependency
        matrix). Default is `'assoc'`, which means the use of statistical association
        (correlation coefficient, Cramér's V based on Pearson's chi-squared statistic 
        and eta-squared based on the Kruskal-Wallis H-statistic);
        `'pps'` stands for Predictive Power Score.
        NOTE: When a function is passed, it is called with the `explainer.data` and it
        must return a symmetric dependency matrix (`pd.DataFrame` with variable names as
        columns and rows).
    clust_method : {'complete', 'single', 'average', 'weighted', 'centroid', 'median', 'ward'}, optional
        The linkage algorithm to use for variables hierarchical clustering
        (default is `'complete'`).
    corr_method : {'spearman', 'pearson', 'kendall'}, optional
        The method of calculating correlation between numerical variables
        (default is `'spearman'`).
        NOTE: Ignored if `depend_method` is not `'assoc'`.
    agg_method : {'max', 'min', 'avg'}, optional
        The method of aggregating the PPS values for pairs of variables
        (default is `'max'`).
        NOTE: Ignored if `depend_method` is not `'pps'`.

    Attributes
    --------
    explainer : Explainer object
        Model wrapper created using the Explainer class.
    depend_method : {'assoc', 'pps'} or function
        The method of calculating the dependencies between variables.
    clust_method : {'complete', 'single', 'average', 'weighted', 'centroid', 'median', 'ward'}
        The linkage algorithm to use for variables hierarchical clustering.
    corr_method : {'spearman', 'pearson', 'kendall'}
        The method of calculating correlation between numerical variables.
    agg_method : {'max', 'min', 'avg'}
        The method of aggregating the PPS values for pairs of variables.
    depend_matrix : pd.DataFrame
        The dependency matrix (with variable names as columns and rows).
    linkage_matrix :
        The hierarchical clustering of variables encoded as a `scipy` linkage matrix.

    Notes
    -----
    - assoc, eta-squared: http://tss.awf.poznan.pl/files/3_Trends_Vol21_2014__no1_20.pdf
    - assoc, Cramér's V: http://stats.lse.ac.uk/bergsma/pdf/cramerV3.pdf
    - PPS: https://github.com/8080labs/ppscore
    - triplot: https://arxiv.org/abs/2104.03403
    """

    def __init__(
        self,
        explainer,
        depend_method="assoc",
        clust_method="complete",
        corr_method="spearman",
        agg_method="max",
    ):  
        _depend_method, _corr_method, _agg_method = checks.check_method_depend(depend_method, corr_method, agg_method)
        self.explainer = explainer
        self.depend_method = _depend_method
        self.clust_method = clust_method
        self.corr_method = _corr_method
        self.agg_method = _agg_method
        self.depend_matrix = utils.calculate_depend_matrix(
            self.explainer.data, self.depend_method, self.corr_method, self.agg_method
        )
        self.linkage_matrix = utils.calculate_linkage_matrix(
            self.depend_matrix, clust_method
        )
        self._hierarchical_clustering_dendrogram = plot.plot_dendrogram(
            self.linkage_matrix, self.depend_matrix.columns
        )
        self._dendrogram_aspects_ordered = utils.get_dendrogram_aspects_ordered(
            self._hierarchical_clustering_dendrogram, self.depend_matrix
        )
        self._full_hierarchical_aspect_importance = None
        self._mt_params = None

    def get_aspects(self, h=0.5, n=None):
        """Form aspects of variables from the hierarchical clustering

        Parameters
        ----------
        h : float, optional
            Threshold to apply when forming aspects, i.e., the minimum value of the dependency
            between the variables grouped in one aspect (default is `0.5`).
            NOTE: Ignored if `n` is not `None`.
        n : int, optional
            Maximum number of aspects to form 
            (default is `None`, which means the use of `h` parameter).

        Returns
        -------
        dict of lists
            Variables grouped in aspects, e.g. `{'aspect_1': ['x1', 'x2'], 'aspect_2': ['y1', 'y2']}`.
        """
        from scipy.cluster.hierarchy import fcluster

        if n is None:
            aspect_label = fcluster(self.linkage_matrix, 1 - h, criterion="distance")
        else:
            aspect_label = fcluster(self.linkage_matrix, n, criterion="maxclust")
        aspects = pd.DataFrame(
            {"feature": self.depend_matrix.columns, "aspect": aspect_label}
        )
        aspects = aspects.groupby("aspect")["feature"].apply(list).reset_index()
        aspects_dict = {}

        # name single-variable aspects after their only variable
        i = 1
        for index, row in aspects.iterrows():
            if len(row["feature"]) > 1:
                aspects_dict[f"aspect_{i}"] = row["feature"]
                i += 1
            else:
                aspects_dict[row["feature"][0]] = row["feature"]

        return aspects_dict

    def plot_dendrogram(
        self,
        title="Hierarchical clustering dendrogram",
        lines_interspace=20,
        rounding_function=np.round,
        digits=3,
        show=True,
    ):
        """Plot the hierarchical clustering dendrogram of variables

        Parameters
        ----------
        title : str, optional
            Title of the plot (default is "Hierarchical clustering dendrogram").
        lines_interspace : float, optional
            Interspace between lines of dendrogram in px (default is `20`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """
        m = len(self.depend_matrix.columns)
        plot_height = 78 + 71 + m * lines_interspace + (m + 1) * lines_interspace / 4
        fig = self._hierarchical_clustering_dendrogram
        fig = plot.add_text_and_tooltips_to_dendrogram(
            fig, self._dendrogram_aspects_ordered, rounding_function, digits
        )
        fig = plot._add_points_on_dendrogram_traces(fig)
        fig.update_layout(
            title={"text": title, "x": 0.15},
            yaxis={"automargin": True, "autorange": "reversed"},
            height=plot_height,
        )
        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

    def predict_parts(
        self,
        new_observation,
        variable_groups=None,
        type="default",
        h=0.5,
        N=2000,
        B=25,
        n_aspects=None,
        sample_method="default",
        f=2,
        label=None,
        processes=1,
        random_state=None,
    ):
        """Calculate predict-level aspect importance

        Parameters
        ----------
        new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
            An observation for which a prediction needs to be explained.
        variable_groups : dict of lists or None
            Variables grouped in aspects to calculate their importance (default is `None`).
        type : {'default', 'shap'}, optional
            Type of aspect importance/attributions (default is `'default'`, which means
            the use of simplified LIME method).
        h : float, optional
            Threshold to apply when forming aspects, i.e., the minimum value of the dependency
            between the variables grouped in one aspect (default is `0.5`).
        N : int, optional
            Number of observations that will be sampled from the `explainer.data` attribute
            before the calculation of aspect importance (default is `2000`).
        B : int, optional
            Parameter specific for `type == 'shap'`. Number of random paths to calculate aspect
            attributions (default is `25`).
            NOTE: Ignored if `type` is not `'shap'`.
        n_aspects : int, optional
            Parameter specific for `type == 'default'`. Maximum number of non-zero importances, i.e.
            coefficients after lasso fitting (default is `None`, which means the linear regression is used).
            NOTE: Ignored if `type` is not `'default'`.
        sample_method : {'default', 'binom'}, optional
            Parameter specific for `type == 'default'`. Sampling method for creating binary matrix
            used as mask for replacing aspects in sampled data (default is `'default'`, which means
            it randomly replaces one or two zeros per row; `'binom'` replaces random number of zeros
            per row).
            NOTE: Ignored if `type` is not `'default'`.
        f : int, optional
            Parameter specific for `type == 'default'` and `sample_method == 'binom'`. Parameter
            controlling average number of replaced zeros for binomial sampling (default is `2`).
            NOTE: Ignored if `type` is not `'default'` or `sample_method` is not `'binom'`.
        label : str, optional
            Name to appear in result and plots. Overrides default.
        processes : int, optional
            Parameter specific for `type == 'shap'`. Number of parallel processes to use in calculations.
            Iterated over `B` (default is `1`, which means no parallel computation).
        random_state : int, optional
            Set seed for random number generator (default is random seed).

        Returns
        -------
        PredictAspectImportance class object
            Explanation object containing the main result attribute and the plot method.
        """

        if variable_groups is None:
            variable_groups = self.get_aspects(h)

        pai = PredictAspectImportance(
            variable_groups,
            type,
            N,
            B,
            n_aspects,
            sample_method,
            f,
            self.depend_method,
            self.corr_method,
            self.agg_method,
            processes,
            random_state,
            _depend_matrix=self.depend_matrix
        )

        pai.fit(self.explainer, new_observation)

        if label is not None:
            pai.result["label"] = label

        return pai

    def model_parts(
        self,
        variable_groups=None,
        h=0.5,
        loss_function=None,
        type="variable_importance",
        N=1000,
        B=10,
        processes=1,
        label=None,
        random_state=None,
    ):
        """Calculate model-level aspect importance

        Parameters
        ----------
        variable_groups : dict of lists or None
            Variables grouped in aspects to calculate their importance (default is `None`).
        h : float, optional
            Threshold to apply when forming aspects, i.e., the minimum value of the dependency
            between the variables grouped in one aspect (default is `0.5`).
        loss_function :  {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
            If string, then such loss function will be used to assess aspect importance
            (default is `'rmse'` or `'1-auc'`, depends on `explainer.model_type` attribute).
        type : {'variable_importance', 'ratio', 'difference'}, optional
            Type of transformation that will be applied to dropout loss
            (default is `'variable_importance'`, which is Permutational Variable Importance).
        N : int, optional
            Number of observations that will be sampled from the `explainer.data` attribute before
            the calculation of aspect importance. `None` means all `data` (default is `1000`).
        B : int, optional
            Number of permutation rounds to perform on each variable (default is `10`).
        processes : int, optional
            Number of parallel processes to use in calculations. Iterated over `B`
            (default is `1`, which means no parallel computation).
        label : str, optional
            Name to appear in result and plots. Overrides default.
        random_state : int, optional
            Set seed for random number generator (default is random seed).

        Returns
        -------
        ModelAspectImportance class object
            Explanation object containing the main result attribute and the plot method.
        """

        loss_function = checks.check_method_loss_function(self.explainer, loss_function)
        mai_result = None

        if variable_groups is None:
            variable_groups = self.get_aspects(h)

            # get results from triplot if it was precalculated with the same params
            if self._full_hierarchical_aspect_importance is not None:
                if (
                    self._mt_params["loss_function"] == loss_function
                    and self._mt_params["N"] == N
                    and self._mt_params["B"] == B
                    and self._mt_params["type"] == type
                ):
                    h = min(1, h)
                    h_selected = np.unique(
                        self._full_hierarchical_aspect_importance.loc[
                            self._full_hierarchical_aspect_importance.h >= h
                        ].h
                    )[0]
                    mai_result = self._full_hierarchical_aspect_importance.loc[
                        self._full_hierarchical_aspect_importance.h == h_selected
                    ]

        ai = ModelAspectImportance(
            loss_function=loss_function,
            type=type,
            N=N,
            B=B,
            variable_groups=variable_groups,
            processes=processes,
            random_state=random_state,
            _depend_matrix=self.depend_matrix
        )

        # calculate only if there is no precalculated result
        if mai_result is None:
            ai.fit(self.explainer)
        else: 
            mai_result = mai_result[
                [
                    "aspect_name",
                    "variable_names",
                    "dropout_loss",
                    "dropout_loss_change",
                    "min_depend",
                    "vars_min_depend",
                    "label",
                ]
            ]
            ai.result = mai_result

        if label is not None:
            ai.result["label"] = label

        return ai

    def predict_triplot(
        self,
        new_observation,
        type="default",
        N=2000,
        B=25,
        sample_method="default",
        f=2,
        processes=1,
        random_state=None,
    ):
        """Calculate predict-level hierarchical aspect importance

        Parameters
        ----------
        new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
            An observation for which a prediction needs to be explained.
        type : {'default', 'shap'}, optional
            Type of aspect importance/attributions (default is `'default'`, which means
            the use of simplified LIME method).
        N : int, optional
            Number of observations that will be sampled from the `explainer.data` attribute
            before the calculation of aspect importance (default is `2000`).
        B : int, optional
            Parameter specific for `type == 'shap'`. Number of random paths to calculate aspect
            attributions (default is `25`).
            NOTE: Ignored if `type` is not `'shap'`.
        sample_method : {'default', 'binom'}, optional
            Parameter specific for `type == 'default'`. Sampling method for creating binary matrix
            used as mask for replacing aspects in data (default is `'default'`, which means
            it randomly replaces one or two zeros per row; `'binom'` replaces random number of zeros
            per row).
            NOTE: Ignored if `type` is not `'default'`.
        f : int, optional
            Parameter specific for `type == 'default'` and `sample_method == 'binom'`. Parameter
            controlling average number of replaced zeros for binomial sampling (default is `2`).
            NOTE: Ignored if `type` is not `'default'` or `sample_method` is not `'binom'`.
        processes : int, optional
            Number of parallel processes to use in calculations. Iterated over `B`
            (default is `1`, which means no parallel computation).
        random_state : int, optional
            Set seed for random number generator (default is random seed).

        Returns
        -------
        PredictTriplot class object
            Explanation object containing the main result attribute and the plot method.
        """

        pt = PredictTriplot(type, N, B, sample_method, f, processes, random_state)

        pt.fit(self, new_observation)

        return pt

    def model_triplot(
        self,
        loss_function=None,
        type="variable_importance",
        N=1000,
        B=10,
        processes=1,
        random_state=None,
    ):
        """Calculate model-level hierarchical aspect importance

        Parameters
        ----------
        loss_function :  {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
            If string, then such loss function will be used to assess aspect importance
            (default is `'rmse'` or `'1-auc'`, depends on `explainer.model_type` attribute).
        type : {'variable_importance', 'ratio', 'difference'}, optional
            Type of transformation that will be applied to dropout loss
            (default is `'variable_importance'`, which is Permutational Variable Importance).
        N : int, optional
            Number of observations that will be sampled from the `explainer.data` attribute before
            the calculation of aspect importance. `None` means all `data` (default is `1000`).
        B : int, optional
            Number of permutation rounds to perform on each variable (default is `10`).
        processes : int, optional
            Number of parallel processes to use in calculations. Iterated over `B`
            (default is `1`, which means no parallel computation).
        random_state : int, optional
            Set seed for random number generator (default is random seed).

        Returns
        -------
        ModelTriplot class object
            Explanation object containing the main result attribute and the plot method.
        """

        
        loss_function = checks.check_method_loss_function(self.explainer, loss_function) # get proper loss_function for model_type
        mt = ModelTriplot(loss_function, type, N, B, processes, random_state)
        self._mt_params = {"loss_function": loss_function, "type": type, "N": N, "B": B} # save params for future calls of model_parts
        mt.fit(self)

        return mt

Methods

def get_aspects(self, h=0.5, n=None)
def get_aspects(self, h=0.5, n=None):
    """Form aspects of variables from the hierarchical clustering

    Parameters
    ----------
    h : float, optional
        Threshold to apply when forming aspects, i.e., the minimum value of the dependency
        between the variables grouped in one aspect (default is `0.5`).
        NOTE: Ignored if `n` is not `None`.
    n : int, optional
        Maximum number of aspects to form 
        (default is `None`, which means the use of `h` parameter).

    Returns
    -------
    dict of lists
        Variables grouped in aspects, e.g. `{'aspect_1': ['x1', 'x2'], 'aspect_2': ['y1', 'y2']}`.
    """
    from scipy.cluster.hierarchy import fcluster

    if n is None:
        aspect_label = fcluster(self.linkage_matrix, 1 - h, criterion="distance")
    else:
        aspect_label = fcluster(self.linkage_matrix, n, criterion="maxclust")
    aspects = pd.DataFrame(
        {"feature": self.depend_matrix.columns, "aspect": aspect_label}
    )
    aspects = aspects.groupby("aspect")["feature"].apply(list).reset_index()
    aspects_dict = {}

    # name single-variable aspects after their only variable
    i = 1
    for index, row in aspects.iterrows():
        if len(row["feature"]) > 1:
            aspects_dict[f"aspect_{i}"] = row["feature"]
            i += 1
        else:
            aspects_dict[row["feature"][0]] = row["feature"]

    return aspects_dict
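
A usage sketch (variable names illustrative). Note that internally the threshold h is translated into a dendrogram cut at linkage distance 1 - h:

aspect.get_aspects(h=0.3)
# e.g. {'aspect_1': ['x1', 'x2'], 'aspect_2': ['x3', 'x4'], 'x5': ['x5']}

aspect.get_aspects(n=3)  # form at most 3 aspects instead
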
def model_parts(self, variable_groups=None, h=0.5, loss_function=None, type='variable_importance', N=1000, B=10, processes=1, label=None, random_state=None)

Calculate model-level aspect importance

Parameters

variable_groups : dict of lists or None
Variables grouped in aspects to calculate their importance (default is None).
h : float, optional
Threshold to apply when forming aspects, i.e., the minimum value of the dependency between the variables grouped in one aspect (default is 0.5).
loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
If string, then such loss function will be used to assess aspect importance (default is 'rmse' or '1-auc', depends on explainer.model_type attribute).
type : {'variable_importance', 'ratio', 'difference'}, optional
Type of transformation that will be applied to dropout loss (default is 'variable_importance', which is Permutational Variable Importance).
N : int, optional
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance. None means all data (default is 1000).
B : int, optional
Number of permutation rounds to perform on each variable (default is 10).
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
label : str, optional
Name to appear in result and plots. Overrides default.
random_state : int, optional
Set seed for random number generator (default is random seed).

Returns

ModelAspectImportance class object
Explanation object containing the main result attribute and the plot method.
def model_parts(
    self,
    variable_groups=None,
    h=0.5,
    loss_function=None,
    type="variable_importance",
    N=1000,
    B=10,
    processes=1,
    label=None,
    random_state=None,
):
    """Calculate model-level aspect importance

    Parameters
    ----------
    variable_groups : dict of lists or None
        Variables grouped in aspects to calculate their importance (default is `None`).
    h : float, optional
        Threshold to apply when forming aspects, i.e., the minimum value of the dependency
        between the variables grouped in one aspect (default is `0.5`).
    loss_function :  {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
        If string, then such loss function will be used to assess aspect importance
        (default is `'rmse'` or `'1-auc'`, depends on `explainer.model_type` attribute).
    type : {'variable_importance', 'ratio', 'difference'}, optional
        Type of transformation that will be applied to dropout loss
        (default is `'variable_importance'`, which is Permutational Variable Importance).
    N : int, optional
        Number of observations that will be sampled from the `explainer.data` attribute before
        the calculation of aspect importance. `None` means all `data` (default is `1000`).
    B : int, optional
        Number of permutation rounds to perform on each variable (default is `10`).
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    label : str, optional
        Name to appear in result and plots. Overrides default.
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Returns
    -------
    ModelAspectImportance class object
        Explanation object containing the main result attribute and the plot method.
    """

    loss_function = checks.check_method_loss_function(self.explainer, loss_function)
    mai_result = None

    if variable_groups is None:
        variable_groups = self.get_aspects(h)

        # get results from triplot if it was precalculated with the same params
        if self._full_hierarchical_aspect_importance is not None:
            if (
                self._mt_params["loss_function"] == loss_function
                and self._mt_params["N"] == N
                and self._mt_params["B"] == B
                and self._mt_params["type"] == type
            ):
                h = min(1, h)
                h_selected = np.unique(
                    self._full_hierarchical_aspect_importance.loc[
                        self._full_hierarchical_aspect_importance.h >= h
                    ].h
                )[0]
                mai_result = self._full_hierarchical_aspect_importance.loc[
                    self._full_hierarchical_aspect_importance.h == h_selected
                ]

    ai = ModelAspectImportance(
        loss_function=loss_function,
        type=type,
        N=N,
        B=B,
        variable_groups=variable_groups,
        processes=processes,
        random_state=random_state,
        _depend_matrix=self.depend_matrix
    )

    # calculate only if there is no precalculated result
    if mai_result is None:
        ai.fit(self.explainer)
    else: 
        mai_result = mai_result[
            [
                "aspect_name",
                "variable_names",
                "dropout_loss",
                "dropout_loss_change",
                "min_depend",
                "vars_min_depend",
                "label",
            ]
        ]
        ai.result = mai_result

    if label is not None:
        ai.result["label"] = label

    return ai
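
A usage sketch, continuing the aspect object from above:

mai = aspect.model_parts(h=0.5, random_state=42)
mai.result   # pd.DataFrame with dropout_loss per aspect
mai.plot()
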
def model_triplot(self, loss_function=None, type='variable_importance', N=1000, B=10, processes=1, random_state=None)

Calculate model-level hierarchical aspect importance

Parameters

loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
If string, then such loss function will be used to assess aspect importance (default is 'rmse' or '1-auc', depends on explainer.model_type attribute).
type : {'variable_importance', 'ratio', 'difference'}, optional
Type of transformation that will be applied to dropout loss (default is 'variable_importance', which is Permutational Variable Importance).
N : int, optional
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance. None means all data (default is 1000).
B : int, optional
Number of permutation rounds to perform on each variable (default is 10).
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Returns

ModelTriplot class object
Explanation object containing the main result attribute and the plot method.
def model_triplot(
    self,
    loss_function=None,
    type="variable_importance",
    N=1000,
    B=10,
    processes=1,
    random_state=None,
):
    """Calculate model-level hierarchical aspect importance

    Parameters
    ----------
    loss_function :  {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
        If string, then such loss function will be used to assess aspect importance
        (default is `'rmse'` or `'1-auc'`, depends on `explainer.model_type` attribute).
    type : {'variable_importance', 'ratio', 'difference'}, optional
        Type of transformation that will be applied to dropout loss
        (default is `'variable_importance'`, which is Permutational Variable Importance).
    N : int, optional
        Number of observations that will be sampled from the `explainer.data` attribute before
        the calculation of aspect importance. `None` means all `data` (default is `1000`).
    B : int, optional
        Number of permutation rounds to perform on each variable (default is `10`).
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Returns
    -------
    ModelTriplot class object
        Explanation object containing the main result attribute and the plot method.
    """

    
    loss_function = checks.check_method_loss_function(self.explainer, loss_function) # get proper loss_function for model_type
    mt = ModelTriplot(loss_function, type, N, B, processes, random_state)
    self._mt_params = {"loss_function": loss_function, "type": type, "N": N, "B": B} # save params for future calls of model_parts
    mt.fit(self)

    return mt
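
A usage sketch; the resulting explanation visualises aspect importance across all dendrogram cut levels:

mt = aspect.model_triplot(B=5, random_state=42)
mt.plot()
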
def plot_dendrogram(self, title='Hierarchical clustering dendrogram', lines_interspace=20, rounding_function=np.round, digits=3, show=True)

Plot the hierarchical clustering dendrogram of variables

Parameters

title : str, optional
Title of the plot (default is "Hierarchical clustering dendrogram").
lines_interspace : float, optional
Interspace between lines of dendrogram in px (default is 20).
rounding_function : function, optional
A function that will be used for rounding numbers (default is np.around).
digits : int, optional
Number of decimal places (np.around) to round contributions. See rounding_function parameter (default is 3).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
def plot_dendrogram(
    self,
    title="Hierarchical clustering dendrogram",
    lines_interspace=20,
    rounding_function=np.round,
    digits=3,
    show=True,
):
    """Plot the hierarchical clustering dendrogram of variables

    Parameters
    ----------
    title : str, optional
        Title of the plot (default is "Hierarchical clustering dendrogram").
    lines_interspace : float, optional
        Interspace between lines of dendrogram in px (default is `20`).
    rounding_function : function, optional
        A function that will be used for rounding numbers (default is `np.around`).
    digits : int, optional
        Number of decimal places (`np.around`) to round contributions.
        See `rounding_function` parameter (default is `3`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """
    m = len(self.depend_matrix.columns)
    plot_height = 78 + 71 + m * lines_interspace + (m + 1) * lines_interspace / 4
    fig = self._hierarchical_clustering_dendrogram
    fig = plot.add_text_and_tooltips_to_dendrogram(
        fig, self._dendrogram_aspects_ordered, rounding_function, digits
    )
    fig = plot._add_points_on_dendrogram_traces(fig)
    fig.update_layout(
        title={"text": title, "x": 0.15},
        yaxis={"automargin": True, "autorange": "reversed"},
        height=plot_height,
    )
    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
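
A usage sketch of the show=False variant; saving via write_image() requires the kaleido package:

fig = aspect.plot_dendrogram(show=False)
fig.write_image("dendrogram.png")  # static export needs kaleido installed
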
def predict_parts(self, new_observation, variable_groups=None, type='default', h=0.5, N=2000, B=25, n_aspects=None, sample_method='default', f=2, label=None, processes=1, random_state=None)

Calculate predict-level aspect importance

Parameters

new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
An observation for which a prediction needs to be explained.
variable_groups : dict of lists or None
Variables grouped in aspects to calculate their importance (default is None).
type : {'default', 'shap'}, optional
Type of aspect importance/attributions (default is 'default', which means the use of simplified LIME method).
h : float, optional
Threshold to apply when forming aspects, i.e., the minimum value of the dependency between the variables grouped in one aspect (default is 0.5).
N : int, optional
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance (default is 2000).
B : int, optional
Parameter specific for type == 'shap'. Number of random paths to calculate aspect attributions (default is 25). NOTE: Ignored if type is not 'shap'.
n_aspects : int, optional
Parameter specific for type == 'default'. Maximum number of non-zero importances, i.e. coefficients after lasso fitting (default is None, which means the linear regression is used). NOTE: Ignored if type is not 'default'.
sample_method : {'default', 'binom'}, optional
Parameter specific for type == 'default'. Sampling method for creating binary matrix used as mask for replacing aspects in sampled data (default is 'default', which means it randomly replaces one or two zeros per row; 'binom' replaces random number of zeros per row). NOTE: Ignored if type is not 'default'.
f : int, optional
Parameter specific for type == 'default' and sample_method == 'binom'. Parameter controlling average number of replaced zeros for binomial sampling (default is 2). NOTE: Ignored if type is not 'default' or sample_method is not 'binom'.
label : str, optional
Name to appear in result and plots. Overrides default.
processes : int, optional
Parameter specific for type == 'shap'. Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Returns

PredictAspectImportance class object
Explanation object containing the main result attribute and the plot method.
def predict_parts(
    self,
    new_observation,
    variable_groups=None,
    type="default",
    h=0.5,
    N=2000,
    B=25,
    n_aspects=None,
    sample_method="default",
    f=2,
    label=None,
    processes=1,
    random_state=None,
):
    """Calculate predict-level aspect importance

    Parameters
    ----------
    new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
        An observation for which a prediction needs to be explained.
    variable_groups : dict of lists or None
        Variables grouped in aspects to calculate their importance (default is `None`).
    type : {'default', 'shap'}, optional
        Type of aspect importance/attributions (default is `'default'`, which means
        the use of simplified LIME method).
    h : float, optional
        Threshold to apply when forming aspects, i.e., the minimum value of the dependency
        between the variables grouped in one aspect (default is `0.5`).
    N : int, optional
        Number of observations that will be sampled from the `explainer.data` attribute
        before the calculation of aspect importance (default is `2000`).
    B : int, optional
        Parameter specific for `type == 'shap'`. Number of random paths to calculate aspect
        attributions (default is `25`).
        NOTE: Ignored if `type` is not `'shap'`.
    n_aspects : int, optional
        Parameter specific for `type == 'default'`. Maximum number of non-zero importances, i.e.
        coefficients after lasso fitting (default is `None`, which means the linear regression is used).
        NOTE: Ignored if `type` is not `'default'`.
    sample_method : {'default', 'binom'}, optional
        Parameter specific for `type == 'default'`. Sampling method for creating binary matrix
        used as mask for replacing aspects in sampled data (default is `'default'`, which means
        it randomly replaces one or two zeros per row; `'binom'` replaces random number of zeros
        per row).
        NOTE: Ignored if `type` is not `'default'`.
    f : int, optional
        Parameter specific for `type == 'default'` and `sample_method == 'binom'`. Parameter
        controlling average number of replaced zeros for binomial sampling (default is `2`).
        NOTE: Ignored if `type` is not `'default'` or `sample_method` is not `'binom'`.
    label : str, optional
        Name to appear in result and plots. Overrides default.
    processes : int, optional
        Parameter specific for `type == 'shap'`. Number of parallel processes to use in calculations.
        Iterated over `B` (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Returns
    -------
    PredictAspectImportance class object
        Explanation object containing the main result attribute and the plot method.
    """

    if variable_groups is None:
        variable_groups = self.get_aspects(h)

    pai = PredictAspectImportance(
        variable_groups,
        type,
        N,
        B,
        n_aspects,
        sample_method,
        f,
        self.depend_method,
        self.corr_method,
        self.agg_method,
        processes,
        random_state,
        _depend_matrix=self.depend_matrix
    )

    pai.fit(self.explainer, new_observation)

    if label is not None:
        pai.result["label"] = label

    return pai
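
A usage sketch (obs is a single-row pd.DataFrame taken from the data used above):

obs = X.iloc[[0]]
pai = aspect.predict_parts(obs, type="shap", random_state=42)
pai.result
pai.plot()
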
def predict_triplot(self, new_observation, type='default', N=2000, B=25, sample_method='default', f=2, processes=1, random_state=None)

Calculate predict-level hierarchical aspect importance

Parameters

new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
An observation for which a prediction needs to be explained.
type : {'default', 'shap'}, optional
Type of aspect importance/attributions (default is 'default', which means the use of simplified LIME method).
N : int, optional
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance (default is 2000).
B : int, optional
Parameter specific for type == 'shap'. Number of random paths to calculate aspect attributions (default is 25). NOTE: Ignored if type is not 'shap'.
sample_method : {'default', 'binom'}, optional
Parameter specific for type == 'default'. Sampling method for creating binary matrix used as mask for replacing aspects in data (default is 'default', which means it randomly replaces one or two zeros per row; 'binom' replaces random number of zeros per row). NOTE: Ignored if type is not 'default'.
f : int, optional
Parameter specific for type == 'default' and sample_method == 'binom'. Parameter controlling average number of replaced zeros for binomial sampling (default is 2). NOTE: Ignored if type is not 'default' or sample_method is not 'binom'.
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Returns

PredictTriplot class object
Explanation object containing the main result attribute and the plot method.
def predict_triplot(
    self,
    new_observation,
    type="default",
    N=2000,
    B=25,
    sample_method="default",
    f=2,
    processes=1,
    random_state=None,
):
    """Calculate predict-level hierarchical aspect importance

    Parameters
    ----------
    new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
        An observation for which a prediction needs to be explained.
    type : {'default', 'shap'}, optional
        Type of aspect importance/attributions (default is `'default'`, which means
        the use of simplified LIME method).
    N : int, optional
        Number of observations that will be sampled from the `explainer.data` attribute
        before the calculation of aspect importance (default is `2000`).
    B : int, optional
        Parameter specific for `type == 'shap'`. Number of random paths to calculate aspect
        attributions (default is `25`).
        NOTE: Ignored if `type` is not `'shap'`.
    sample_method : {'default', 'binom'}, optional
        Parameter specific for `type == 'default'`. Sampling method for creating binary matrix
        used as mask for replacing aspects in data (default is `'default'`, which means
        it randomly replaces one or two zeros per row; `'binom'` replaces random number of zeros
        per row).
        NOTE: Ignored if `type` is not `'default'`.
    f : int, optional
        Parameter specific for `type == 'default'` and `sample_method == 'binom'`. Parameter
        controlling average number of replaced zeros for binomial sampling (default is `2`).
        NOTE: Ignored if `type` is not `'default'` or `sample_method` is not `'binom'`.
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Returns
    -------
    PredictTriplot class object
        Explanation object containing the main result attribute and the plot method.
    """

    pt = PredictTriplot(type, N, B, sample_method, f, processes, random_state)

    pt.fit(self, new_observation)

    return pt
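
A usage sketch, analogous to predict_parts above:

pt = aspect.predict_triplot(obs, random_state=42)
pt.plot()
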
class ModelAspectImportance (variable_groups, loss_function=None, type='variable_importance', N=1000, B=10, depend_method='assoc', corr_method='spearman', agg_method='max', processes=1, random_state=None, **kwargs)

Calculate model-level aspect importance

Parameters

variable_groups : dict of lists
Variables grouped in aspects to calculate their importance.
loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
If string, then such loss function will be used to assess aspect importance (default is 'rmse' or '1-auc', depends on explainer.model_type attribute).
type : {'variable_importance', 'ratio', 'difference'}, optional
Type of transformation that will be applied to dropout loss (default is 'variable_importance', which is Permutational Variable Importance).
N : int, optional
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance. None means all data (default is 1000).
B : int, optional
Number of permutation rounds to perform on each variable (default is 10).
depend_method : {'assoc', 'pps'} or function, optional
The method of calculating the dependencies between variables (i.e. the dependency matrix). Default is 'assoc', which means the use of statistical association (correlation coefficient, Cramér's V and eta-squared); 'pps' stands for Predictive Power Score. NOTE: When a function is passed, it is called with the data and it must return a symmetric dependency matrix (pd.DataFrame with variable names as columns and rows).
corr_method : {'spearman', 'pearson', 'kendall'}, optional
The method of calculating correlation between numerical variables (default is 'spearman'). NOTE: Ignored if depend_method is not 'assoc'.
agg_method : {'max', 'min', 'avg'}, optional
The method of aggregating the PPS values for pairs of variables (default is 'max'). NOTE: Ignored if depend_method is not 'pps'.
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
variable_groups : dict of lists
Variables grouped in aspects to calculate their importance.
loss_function : function
Loss function used to assess the variable importance.
type : {'variable_importance', 'ratio', 'difference'}
Type of transformation that will be applied to dropout loss.
N : int
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance.
B : int
Number of permutation rounds to perform on each variable.
depend_method : {'assoc', 'pps'}
The method of calculating the dependencies between variables.
corr_method : {'spearman', 'pearson', 'kendall'}
The method of calculating correlation between numerical variables.
agg_method : {'max', 'min', 'avg'}
The method of aggregating the PPS values for pairs of variables.
processes : int
Number of parallel processes to use in calculations. Iterated over B.
random_state : int
Set seed for random number generator.
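
Although this object is usually created through Aspect.model_parts(), it can also be instantiated directly; a minimal sketch (the variable grouping is illustrative):

mai = ModelAspectImportance(
    variable_groups={"aspect_1": ["x1", "x2"], "x3": ["x3"]},
    random_state=42,
)
mai.fit(explainer)  # computes the result in place
mai.result
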
class ModelAspectImportance(VariableImportance):
    """Calculate model-level aspect importance

    Parameters
    ----------
    variable_groups : dict of lists 
        Variables grouped in aspects to calculate their importance.
    loss_function :  {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
        If string, then such loss function will be used to assess aspect importance
        (default is `'rmse'` or `'1-auc'`, depends on `explainer.model_type` attribute).
    type : {'variable_importance', 'ratio', 'difference'}, optional
        Type of transformation that will be applied to dropout loss
        (default is `'variable_importance'`, which is Permutational Variable Importance).
    N : int, optional
        Number of observations that will be sampled from the `explainer.data` attribute before
        the calculation of aspect importance. `None` means all `data` (default is `1000`).
    B : int, optional
        Number of permutation rounds to perform on each variable (default is `10`).
    depend_method : {'assoc', 'pps'} or function, optional
        The method of calculating the dependencies between variables (i.e. the dependency 
        matrix). Default is `'assoc'`, which means the use of statistical association 
        (correlation coefficient, Cramér's V and eta-squared);
        `'pps'` stands for Predictive Power Score.
        NOTE: When a function is passed, it is called with the `data` and it 
        must return a symmetric dependency matrix (`pd.DataFrame` with variable names as 
        columns and rows).
    corr_method : {'spearman', 'pearson', 'kendall'}, optional
        The method of calculating correlation between numerical variables 
        (default is `'spearman'`).
        NOTE: Ignored if `depend_method` is not `'assoc'`.
    agg_method : {'max', 'min', 'avg'}, optional
        The method of aggregating the PPS values for pairs of variables 
        (default is `'max'`).
        NOTE: Ignored if `depend_method` is not `'pps'`. 
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    variable_groups : dict of lists 
        Variables grouped in aspects to calculate their importance. 
    loss_function : function
        Loss function used to assess the variable importance.
    type : {'variable_importance', 'ratio', 'difference'}
        Type of transformation that will be applied to dropout loss.
    N : int
        Number of observations that will be sampled from the `explainer.data` attribute before
        the calculation of aspect importance. 
    B : int
        Number of permutation rounds to perform on each variable.
    depend_method : {'assoc', 'pps'}
        The method of calculating the dependencies between variables.
    corr_method : {'spearman', 'pearson', 'kendall'}
        The method of calculating correlation between numerical variables.
    agg_method : {'max', 'min', 'avg'}
        The method of aggregating the PPS values for pairs of variables.
    processes : int
        Number of parallel processes to use in calculations. Iterated over `B`.
    random_state : int
        Set seed for random number generator.
    """
    def __init__(
        self,
        variable_groups,
        loss_function=None,
        type="variable_importance",
        N=1000,
        B=10,
        depend_method="assoc",
        corr_method="spearman",
        agg_method="max",
        processes=1,
        random_state=None,
        **kwargs
    ):
        super().__init__(
            loss_function,
            type,
            N,
            B,
            None,
            variable_groups,
            True,
            processes,
            random_state,
        )
        _depend_method, _corr_method, _agg_method = checks.check_method_depend(depend_method, corr_method, agg_method)
        self.result = pd.DataFrame()
        self._depend_matrix = None
        if "_depend_matrix" in kwargs:
            self._depend_matrix = kwargs.get("_depend_matrix")
        self.depend_method = _depend_method
        self.corr_method = _corr_method
        self.agg_method = _agg_method
        self.loss_function = loss_function

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self, explainer):
        """Calculate the result of explanation
        Fit method makes calculations in place and changes the attributes.

        Parameters
        ----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        
        Returns
        -----------
        None
        """
        _loss_function = checks.check_method_loss_function(explainer, self.loss_function)
        self.loss_function = checks.check_loss_function(_loss_function)

        super().fit(explainer)

        self.result["variable_names"] = self.result["variable"].map(self.variable_groups)
        baseline = self.result[self.result["variable"] == "_full_model_"]["dropout_loss"].values[0]
        self.result = self.result.assign(
            dropout_loss_change=lambda x: x["dropout_loss"] - baseline
        )
        self.result = self.result.rename(columns={"variable": "aspect_name"})
        self.result.insert(4, "min_depend", None)
        self.result.insert(5, "vars_min_depend", None)
        # if there is _depend_matrix in kwargs (called from Aspect object) 
        if self._depend_matrix is not None:
            vars_min_depend, min_depend = get_min_depend_from_matrix(self._depend_matrix, 
                    self.result.variable_names
                )
        else:
            vars_min_depend, min_depend = calculate_min_depend(
                self.result.variable_names, 
                explainer.data,
                self.depend_method,
                self.corr_method,
                self.agg_method,
            )

        self.result["min_depend"] = min_depend
        self.result["vars_min_depend"] = vars_min_depend

        self.result = self.result[
            [
                "aspect_name",
                "variable_names",
                "dropout_loss",
                "dropout_loss_change",
                "min_depend",
                "vars_min_depend",
                "label",
            ]
        ]
      

    def plot(
        self,
        objects=None,
        max_aspects=10,
        show_variable_names=True,
        digits=3,
        rounding_function=np.around,
        bar_width=25,
        split=("model", "aspect"),
        title="Model Aspect Importance",
        vertical_spacing=None,
        show=True,
    ):
        """Plot the Model Aspect Importance explanation.

        Parameters
        -----------
        objects : ModelAspectImportance object or array_like of ModelAspectImportance objects
            Additional objects to plot in subplots (default is `None`).
        max_aspects : int, optional
            Maximum number of aspects that will be presented for each subplot
            (default is `10`).
        show_variable_names : bool, optional
            `True` shows names of variables grouped in aspects; `False` shows names of aspects
            (default is `True`).
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        bar_width : float, optional
            Width of bars in px (default is `25`).
        split : {'model', 'aspect'}, optional
            Split the subplots by model or aspect (default is `'model'`).
        title : str, optional
            Title of the plot (default is `"Model Aspect Importance"`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots (default is `0.2/number of rows`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """

        if isinstance(split, tuple):
            split = split[0]

        if split not in ("model", "aspect"):
            raise TypeError("split should be 'model' or 'aspect'")

        # are there any other objects to plot?
        if objects is None:
            n = 1
            _result_df = self.result.copy()
            if split == "aspect":  # force split by model if only one explainer
                split = "model"
        elif isinstance(
            objects, self.__class__
        ):  # allow for objects to be a single element
            n = 2
            _result_df = pd.concat([self.result.copy(), objects.result.copy()])
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            n = len(objects) + 1
            _result_df = self.result.copy()
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _result_df = pd.concat([_result_df, ob.result.copy()])
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        dl = _result_df.loc[
            _result_df.aspect_name != "_baseline_", "dropout_loss"
        ].to_numpy()
        min_max_margin = dl.ptp() * 0.15
        min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

        # take out full model
        best_fits = _result_df[_result_df.aspect_name == "_full_model_"]

        # this produces dropout_loss_x and dropout_loss_y columns
        _result_df = _result_df.merge(
            best_fits[["label", "dropout_loss"]], how="left", on="label"
        )
        # remove full_model and baseline
        _result_df = _result_df[
            (_result_df.aspect_name != "_full_model_")
            & (_result_df.aspect_name != "_baseline_")
        ]
        _result_df = _result_df[
            [
                "label",
                "aspect_name",
                "dropout_loss_x",
                "dropout_loss_y",
                "variable_names",
                "min_depend",
                "vars_min_depend",
            ]
        ].rename(
            columns={
                "dropout_loss_x": "dropout_loss",
                "dropout_loss_y": "full_model",
            }
        )
        # calculate order of bars or variable plots (split = 'aspect')
        # get variable permutation
        perm = (
            _result_df[["aspect_name", "dropout_loss"]]
            .groupby("aspect_name")
            .mean()
            .reset_index()
            .sort_values("dropout_loss", ascending=False)
            .aspect_name.values
        )
        model_names = _result_df["label"].unique().tolist()

        if len(model_names) != n:
            raise ValueError("label must be unique for each model")

        plot_height = 78 + 71  # top + bottom margins (see update_layout below)

        colors = _theme.get_default_colors(n, "bar")

        if vertical_spacing is None:
            vertical_spacing = 0.2 / n

        # init plot
        fig = make_subplots(
            rows=n,
            cols=1,
            shared_xaxes=True,
            vertical_spacing=vertical_spacing,
            x_title="drop-out loss",
            subplot_titles=model_names,
        )
        if split == "model":
            # split df by model
            df_list = [v for k, v in _result_df.groupby("label", sort=False)]

            for i, df in enumerate(df_list):
                m = df.shape[0]
                if max_aspects is not None and max_aspects < m:
                    m = max_aspects

                # take only m variables (for max_aspects)
                # sort rows of df by variable permutation and drop unused variables
                df = (
                    df.sort_values("dropout_loss")
                    .tail(m)
                    .set_index("aspect_name")
                    .reindex(perm)
                    .dropna()
                    .reset_index()
                )

                baseline = df.iloc[0, df.columns.get_loc("full_model")]

                df = df.assign(difference=lambda x: x["dropout_loss"] - baseline)

                lt = df.difference.apply(
                    lambda val: "+" + str(rounding_function(np.abs(val), digits))
                    if val > 0
                    else str(rounding_function(np.abs(val), digits))
                )
                tt = df.apply(
                    lambda row: plot.tooltip_text(
                        row, rounding_function, digits
                    ),
                    axis=1,
                )
               
                df = df.assign(label_text=lt, tooltip_text=tt)

                fig.add_shape(
                    type="line",
                    x0=baseline,
                    x1=baseline,
                    y0=-1,
                    y1=m,
                    yref="paper",
                    xref="x",
                    line={"color": "#371ea3", "width": 1.5, "dash": "dot"},
                    row=i + 1,
                    col=1,
                )

                if show_variable_names:
                    y_axis_ticks = [
                        ", ".join(variables_list)
                        for variables_list in df["variable_names"]
                    ]
                else:
                    y_axis_ticks = df["aspect_name"]

                fig.add_bar(
                    orientation="h",
                    y=y_axis_ticks,
                    x=df["difference"].tolist(),
                    textposition="outside",
                    text=df["label_text"].tolist(),
                    marker_color=colors[i],
                    base=baseline,
                    hovertext=df["tooltip_text"].tolist(),
                    hoverinfo="text",
                    hoverlabel={"bgcolor": "rgba(0,0,0,0.8)"},
                    showlegend=False,
                    row=i + 1,
                    col=1,
                )

                fig.update_yaxes(
                    {
                        "type": "category",
                        "autorange": "reversed",
                        "gridwidth": 2,
                        "automargin": True,
                        "ticks": "outside",
                        "tickcolor": "white",
                        "ticklen": 10,
                        "fixedrange": True,
                    },
                    row=i + 1,
                    col=1,
                )

                fig.update_xaxes(
                    {
                        "type": "linear",
                        "gridwidth": 2,
                        "zeroline": False,
                        "automargin": True,
                        "ticks": "outside",
                        "tickcolor": "white",
                        "ticklen": 3,
                        "fixedrange": True,
                    },
                    row=i + 1,
                    col=1,
                )

                plot_height += m * bar_width + (m + 1) * bar_width / 4 + 30
        elif split == "aspect":
            # split df by aspect
            df_list = [v for k, v in _result_df.groupby("aspect_name", sort=False)]

            n = len(df_list)
            if max_aspects is not None and max_aspects < n:
                n = max_aspects

            if vertical_spacing is None:
                vertical_spacing = 0.2 / n
            # init plot
            variable_names = perm[0:n]
            fig = make_subplots(
                rows=n,
                cols=1,
                shared_xaxes=True,
                vertical_spacing=vertical_spacing,
                x_title="drop-out loss",
                subplot_titles=variable_names,
            )

            df_dict = {e.aspect_name.array[0]: e for e in df_list}

            # take only n=max_aspects elements from df_dict
            for i in range(n):
                df = df_dict[perm[i]]
                m = df.shape[0]

                baseline = 0

                df = df.assign(difference=lambda x: x["dropout_loss"] - x["full_model"])

                lt = df.difference.apply(
                    lambda val: "+" + str(rounding_function(np.abs(val), digits))
                    if val > 0
                    else str(rounding_function(np.abs(val), digits))
                )
                tt = df.apply(
                    lambda row: plot.tooltip_text(row, rounding_function, digits),
                    axis=1,
                )
                df = df.assign(label_text=lt, tooltip_text=tt)

                fig.add_shape(
                    type="line",
                    x0=baseline,
                    x1=baseline,
                    y0=-1,
                    y1=m,
                    yref="paper",
                    xref="x",
                    line={"color": "#371ea3", "width": 1.5, "dash": "dot"},
                    row=i + 1,
                    col=1,
                )

                fig.add_bar(
                    orientation="h",
                    y=df["label"].tolist(),
                    x=df["dropout_loss"].tolist(),
                    textposition="outside",
                    text=df["label_text"].tolist(),
                    marker_color=colors,
                    base=baseline,
                    hovertext=df["tooltip_text"].tolist(),
                    hoverinfo="text",
                    hoverlabel={"bgcolor": "rgba(0,0,0,0.8)"},
                    showlegend=False,
                    row=i + 1,
                    col=1,
                )

                fig.update_yaxes(
                    {
                        "type": "category",
                        "autorange": "reversed",
                        "gridwidth": 2,
                        "automargin": True,
                        "ticks": "outside",
                        "tickcolor": "white",
                        "ticklen": 10,
                        "dtick": 1,
                        "fixedrange": True,
                    },
                    row=i + 1,
                    col=1,
                )

                fig.update_xaxes(
                    {
                        "type": "linear",
                        "gridwidth": 2,
                        "zeroline": False,
                        "automargin": True,
                        "ticks": "outside",
                        "tickcolor": "white",
                        "ticklen": 3,
                        "fixedrange": True,
                    },
                    row=i + 1,
                    col=1,
                )

                plot_height += m * bar_width + (m + 1) * bar_width / 4

        plot_height += (n - 1) * 70

        fig.update_xaxes({"range": min_max})
        fig.update_layout(
            title_text=title,
            title_x=0.15,
            font={"color": "#371ea3"},
            template="none",
            height=plot_height,
            margin={"t": 78, "b": 71, "r": 30},
        )

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex.model_explanations._variable_importance.object.VariableImportance
  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, explainer)

Calculate the result of the explanation. The fit method makes calculations in place and changes the attributes.

Parameters

explainer : Explainer object
Model wrapper created using the Explainer class.

Returns

None
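In practice this explanation is usually created and fitted through an Aspect object rather than by hand. A minimal sketch, assuming the Aspect helper method model_parts() as the entry point (the diabetes data and random forest model are illustrative choices, not requirements):

import dalex as dx
from dalex.aspect import Aspect
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor

# fit any scikit-learn model on a tabular dataset
data = load_diabetes(as_frame=True)
X, y = data.data, data.target
model = RandomForestRegressor(random_state=0).fit(X, y)

# wrap it in an Explainer, then in an Aspect (groups dependent variables)
explainer = dx.Explainer(model, X, y, label="rf")
aspect = Aspect(explainer)

# model_parts() constructs a ModelAspectImportance object and calls fit()
mai = aspect.model_parts()
mai.result.head()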
 
def plot(self, objects=None, max_aspects=10, show_variable_names=True, digits=3, rounding_function=<function around>, bar_width=25, split=('model', 'aspect'), title='Model Aspect Importance', vertical_spacing=None, show=True)

Plot the Model Aspect Importance explanation.

Parameters

objects : ModelAspectImportance object or array_like of ModelAspectImportance objects
Additional objects to plot in subplots (default is None).
max_aspects : int, optional
Maximum number of aspects that will be presented for each subplot (default is 10).
show_variable_names : bool, optional
True shows names of variables grouped in aspects; False shows names of aspects (default is True).
digits : int, optional
Number of decimal places (np.around) to round contributions. See rounding_function parameter (default is 3).
rounding_function : function, optional
A function that will be used for rounding numbers (default is np.around).
bar_width : float, optional
Width of bars in px (default is 25).
split : {'model', 'aspect'}, optional
Split the subplots by model or aspect (default is 'model').
title : str, optional
Title of the plot (default is "Model Aspect Importance").
vertical_spacing : float <0, 1>, optional
Ratio of vertical space between the plots (default is 0.2/number of rows).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
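A short sketch of the non-interactive workflow, continuing the mai object from the example above (exporting with write_image requires the kaleido package):

# show=False returns the plotly Figure instead of displaying it
fig = mai.plot(max_aspects=5, show_variable_names=True, show=False)
fig.update_layout(title_text="Aspect importance (top 5 aspects)")
fig.write_image("model_aspect_importance.png")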
class ModelTriplot (loss_function=None, type='variable_importance', N=1000, B=10, processes=1, random_state=None)

Calculate model-level hierarchical aspect importance

Parameters

loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
If string, then such a loss function will be used to assess aspect importance (default is 'rmse' or '1-auc', depending on the explainer.model_type attribute).
type : {'variable_importance', 'ratio', 'difference'}, optional
Type of transformation that will be applied to dropout loss (default is 'variable_importance', which is Permutational Variable Importance).
N : int, optional
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance. None means all data (default is 1000).
B : int, optional
Number of permutation rounds to perform on each variable (default is 10).
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
single_variable_importance : pd.DataFrame
Additional result attribute of an explanation (it contains information about the importance of individual variables).
loss_function : function
Loss function used to assess the variable importance.
type : {'variable_importance', 'ratio', 'difference'}
Type of transformation that will be applied to dropout loss.
N : int
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance. None means all data (default is 1000).
B : int
Number of permutation rounds to perform on each variable (default is 10).
processes : int
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int or None
Set seed for random number generator.

Notes

  • https://arxiv.org/abs/2104.03403

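As with ModelAspectImportance, the usual entry point is an Aspect object. A sketch assuming the Aspect helper method model_triplot() and the explainer/aspect pair from the earlier example:

# model_triplot() constructs a ModelTriplot object and calls fit(aspect)
mt = aspect.model_triplot()
mt.result                       # hierarchical aspect importance (middle panel)
mt.single_variable_importance   # individual variable importance (left panel)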
Expand source code Browse git
class ModelTriplot(Explanation):
    """Calculate model-level hierarchical aspect importance

    Parameters
    ----------
    loss_function :  {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
        If string, then such loss function will be used to assess aspect importance
        (default is `'rmse'` or `'1-auc'`, depends on `explainer.model_type` attribute).
    type : {'variable_importance', 'ratio', 'difference'}, optional
        Type of transformation that will be applied to dropout loss
        (default is `'variable_importance'`, which is Permutational Variable Importance).
    N : int, optional
        Number of observations that will be sampled from the `explainer.data` attribute before
        the calculation of aspect importance. `None` means all `data` (default is `1000`).
    B : int, optional
        Number of permutation rounds to perform on each variable (default is `10`).
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    single_variable_importance : pd.DataFrame
        Additional result attribute of an explanation (it contains information 
        about the importance of individual variables).
    loss_function : function
        Loss function used to assess the variable importance.
    type : {'variable_importance', 'ratio', 'difference'}
        Type of transformation that will be applied to dropout loss.
    N : int
        Number of observations that will be sampled from the `explainer.data` attribute before
        the calculation of aspect importance. `None` means all `data` (default is `1000`).
    B : int
        Number of permutation rounds to perform on each variable (default is `10`).
    processes : int
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int or None
        Set seed for random number generator.

    Notes
    -----
    - https://arxiv.org/abs/2104.03403
    """
    def __init__(
        self,
        loss_function=None,
        type="variable_importance",
        N=1000,
        B=10,
        processes=1,
        random_state=None,
    ):
        _B = checks.check_B(B)
        _type = checks.check_type(type)
        _random_state = checks.check_random_state(random_state)
        _processes = checks.check_processes(processes)

        self.loss_function = loss_function
        self.type = _type
        self.N = N
        self.B = _B
        self.random_state = _random_state
        self.processes = _processes
        self.result = pd.DataFrame()
        self.single_variable_importance = None
        self._hierarchical_clustering_dendrogram = None

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self, aspect):
        """Calculate the result of explanation
        Fit method makes calculations in place and changes the attributes.

        Parameters
        ----------
        aspect : Aspect object
            Explainer wrapper created using the Aspect class.
        
        Returns
        -----------
        None
        """
        _loss_function = checks.check_method_loss_function(aspect.explainer, self.loss_function)
        self.loss_function = checks.check_loss_function(_loss_function)

        self._hierarchical_clustering_dendrogram = aspect._hierarchical_clustering_dendrogram
        
        ## middle plot
        (
            self.result,
            aspect._full_hierarchical_aspect_importance,
        ) = utils.calculate_model_hierarchical_importance(
            aspect,
            self.loss_function,
            self.type,
            self.N,
            self.B,
            self.processes,
            self.random_state
        )

        ## left plot data
        self.single_variable_importance, svi_full = utils.calculate_single_variable_importance(
            aspect,
            self.loss_function,
            self.type,
            self.N,
            self.B,
            self.processes,
            self.random_state,
        )
    
        aspect._full_hierarchical_aspect_importance = pd.concat(
            [aspect._full_hierarchical_aspect_importance, svi_full]
        )

    def plot(
        self,
        digits=3,
        rounding_function=np.around,
        show_change=True,
        bar_width=25,
        width=1500,
        vcolors=None,
        title="Model Triplot",
        widget=False,
        show=True
    ):
        """Plot the Model Triplot explanation (triplot visualization).

        Parameters
        ----------
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        show_change : bool, optional
            If `True` middle panel shows dropout loss change, otherwise dropout loss
            (default is `True`). 
        bar_width : float, optional
            Width of bars in px (default is `25`).
        width : float, optional
            Width of triplot in px (default is `1500`).
        vcolors : str, optional
            Color of bars (default is `"#46bac2"`).
        title : str, optional
            Title of the plot (default is `"Model Triplot"`).
        widget : bool, optional
            If `True` triplot interactive widget version is generated
            (default is `False`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object 
            (default is `True`).
            NOTE: Ignored if `widget` is `True`.

        Returns
        -------
        None or plotly.graph_objects.Figure or ipywidgets.HBox with plotly.graph_objs._figurewidget.FigureWidget
            Return figure that can be edited or saved. See `show` parameter.
        """
        _global_checks.global_check_import('kaleido', 'Model Triplot')
        ## right plot
        hierarchical_clustering_dendrogram_plot_without_annotations = (
            self._hierarchical_clustering_dendrogram
        )
        variables_order = list(
            hierarchical_clustering_dendrogram_plot_without_annotations.layout.yaxis.ticktext
        )
        m = len(variables_order)

        ## middle plot
        (
            hierarchical_importance_dendrogram_plot_without_annotations,
            updated_dendro_traces,
        ) = plot.plot_model_hierarchical_importance(
            hierarchical_clustering_dendrogram_plot_without_annotations,
            self.result,
            rounding_function,
            digits,
            show_change
        )

        hierarchical_clustering_dendrogram_plot = plot.add_text_to_dendrogram(
            hierarchical_clustering_dendrogram_plot_without_annotations,
            updated_dendro_traces,
            rounding_function, 
            digits,
            type="clustering",
        )

        hierarchical_importance_dendrogram_plot = plot.add_text_to_dendrogram(
            hierarchical_importance_dendrogram_plot_without_annotations,
            updated_dendro_traces,
            rounding_function, 
            digits,
            type="importance",
        )

        ## left plot
        fig = plot.plot_single_aspects_importance(
            self.single_variable_importance,
            variables_order,
            rounding_function,
            digits,
            vcolors
        )
        
        fig.layout["xaxis"]["range"] = (
            fig.layout["xaxis"]["range"][0],
            fig.layout["xaxis"]["range"][1] * 1.05,
        )
        # align the bars with the dendrogram leaves, which sit at y = -5, -15, ...
        y_vals = [-5 - i * 10 for i in range(m)]
        fig.data[0]["y"] = y_vals

        ## triplot
        min_x_imp, max_x_imp = np.inf, -np.inf
        for data in hierarchical_importance_dendrogram_plot["data"][::-1]:
            data["xaxis"] = "x2"
            data["hoverinfo"] = "text"
            data["line"] = {"color": "#46bac2", "width": 2}
            fig.add_trace(data)
            min_x_imp = np.min([min_x_imp, np.min(data["x"])])
            max_x_imp = np.max([max_x_imp, np.max(data["x"])])
        min_max_margin_imp = (max_x_imp - min_x_imp) * 0.15

        min_x_clust, max_x_clust = np.inf, -np.inf
        for data in hierarchical_clustering_dendrogram_plot["data"]:
            data["xaxis"] = "x3"
            data["hoverinfo"] = "text"
            data["line"] = {"color": "#46bac2", "width": 2}
            fig.add_trace(data)
            min_x_clust = np.min([min_x_clust, np.min(data["x"])])
            max_x_clust = np.max([max_x_clust, np.max(data["x"])])
        min_max_margin_clust = (max_x_clust - min_x_clust) * 0.15

        plot_height = 78 + 71 + m * bar_width + (m + 1) * bar_width / 4

        fig.update_layout(
            xaxis={
                "autorange": False,
                "domain": [0, 0.33],
                "mirror": False,
                "showgrid": False,
                "showline": False,
                "zeroline": False,
                "showticklabels": True,
                "ticks": "",
                "title_text": "Variable importance",
            },
            xaxis2={
                "domain": [0.33, 0.66],
                "mirror": False,
                "showgrid": False,
                "showline": False,
                "zeroline": False,
                "showticklabels": True,
                "tickvals": [0],
                "ticktext": [""],
                "ticks": "",
                "title_text": "Hierarchical aspect importance",
                "fixedrange": True,
                "autorange": False,
                "range": [
                    min_x_imp - min_max_margin_imp,
                    max_x_imp + min_max_margin_imp,
                ],
            },
            xaxis3={
                "domain": [0.66, 0.99],
                "mirror": False,
                "showgrid": False,
                "showline": False,
                "zeroline": False,
                "showticklabels": True,
                "tickvals": [0],
                "ticktext": [""],
                "ticks": "",
                "title_text": "Hierarchical clustering",
                "fixedrange": True,
                "autorange": False,
                "range": [
                    min_x_clust - min_max_margin_clust,
                    max_x_clust + min_max_margin_clust,
                ],
            },
            yaxis={
                "mirror": False,
                "ticks": "",
                "fixedrange": True,
                "gridwidth": 1,
                "type": "linear",
                "tickmode": "array",
                "tickvals": y_vals,
                "ticktext": variables_order,
            },
            title_text=title,
            title_x=0.5,
            font={"color": "#371ea3"},
            template="none",
            margin={"t": 78, "b": 71, "r": 30},
            width=width,
            height=plot_height,
            showlegend=False,
            hovermode="closest",

        )

        fig, middle_point = plot._add_points_on_dendrogram_traces(fig)

        ##################################################################

        if widget:
            _global_checks.global_check_import('ipywidgets', 'Model Triplot')
            from ipywidgets import HBox, Layout
            fig = go.FigureWidget(fig, layout={"autosize": True, "hoverdistance": 100})
            original_bar_colors = deepcopy([fig.data[0]["marker"]["color"]] * m)
            original_text_colors = deepcopy(list(fig.data[0]["textfont"]["color"]))
            k = len(fig.data)
            updated_dendro_traces_in_full_figure = list(
                np.array(updated_dendro_traces) + (k - 1) / 2 + 1
            ) + list((k - 1) / 2 - np.array(updated_dendro_traces))

            def _update_childs(x, y, fig, k, selected, selected_y_cord):
                # recursively walk down the dendrogram: a segment whose middle
                # point equals (x, y) is a child of the clicked segment
                for i in range(1, k):
                    if middle_point[i] == (x, y):
                        fig.data[i]["line"]["color"] = "#46bac2"
                        fig.data[i]["line"]["width"] = 3
                        fig.data[k - i]["line"]["color"] = "#46bac2"
                        fig.data[k - i]["line"]["width"] = 3
                        selected.append(i)
                        selected.append(k - i)
                        if (fig.data[i]["y"][0] + 5) % 10 == 0:
                            selected_y_cord.append((fig.data[i]["y"][0] + 5) // -10)
                        if (fig.data[i]["y"][-1] - 5) % 10 == 0:
                            selected_y_cord.append((fig.data[i]["y"][-1] + 5) // -10)
                        _update_childs(
                            fig.data[i]["x"][0],
                            fig.data[i]["y"][0],
                            fig,
                            k,
                            selected,
                            selected_y_cord,
                        )
                        _update_childs(
                            fig.data[i]["x"][-1],
                            fig.data[i]["y"][-1],
                            fig,
                            k,
                            selected,
                            selected_y_cord,
                        )

            def _update_trace(trace, points, selector):
                # clicking a root segment (one reaching the maximum x) resets the
                # selection; any other segment highlights itself, its mirror
                # trace, and all of its child segments
                if len(points.point_inds) == 1:
                    selected_ind = points.trace_index
                    with fig.batch_update():
                        if max(fig.data[selected_ind]["x"]) in (max_x_clust, max_x_imp):
                            for i in range(1, k):
                                fig.data[i]["line"]["color"] = "#46bac2"
                                fig.data[i]["line"]["width"] = 2
                                fig.data[i]["textfont"]["color"] = "#371ea3"
                                fig.data[i]["textfont"]["size"] = 12
                            fig.data[0]["marker"]["color"] = original_bar_colors
                            fig.data[0]["textfont"]["color"] = original_text_colors
                        else:
                            selected = [selected_ind, k - selected_ind]
                            selected_y_cord = []
                            if (fig.data[selected_ind]["y"][0] - 5) % 10 == 0:
                                selected_y_cord.append(
                                    (fig.data[selected_ind]["y"][0] + 5) // -10
                                )
                            if (fig.data[selected_ind]["y"][-1] - 5) % 10 == 0:
                                selected_y_cord.append(
                                    (fig.data[selected_ind]["y"][-1] + 5) // -10
                                )
                            fig.data[selected_ind]["line"]["color"] = "#46bac2"
                            fig.data[selected_ind]["line"]["width"] = 3
                            fig.data[selected_ind]["textfont"]["color"] = "#371ea3"
                            fig.data[selected_ind]["textfont"]["size"] = 14
                            fig.data[k - selected_ind]["line"]["color"] = "#46bac2"
                            fig.data[k - selected_ind]["line"]["width"] = 3
                            fig.data[k - selected_ind]["textfont"]["color"] = "#371ea3"
                            fig.data[k - selected_ind]["textfont"]["size"] = 14
                            _update_childs(
                                fig.data[selected_ind]["x"][0],
                                fig.data[selected_ind]["y"][0],
                                fig,
                                k,
                                selected,
                                selected_y_cord,
                            )
                            _update_childs(
                                fig.data[selected_ind]["x"][-1],
                                fig.data[selected_ind]["y"][-1],
                                fig,
                                k,
                                selected,
                                selected_y_cord,
                            )
                            for i in range(1, k):
                                if i not in [selected_ind, k - selected_ind]:
                                    fig.data[i]["textfont"]["color"] = "#ceced9"
                                    fig.data[i]["textfont"]["size"] = 12
                                    if i not in selected:
                                        fig.data[i]["line"]["color"] = "#ceced9"
                                        fig.data[i]["line"]["width"] = 1

                            bars_colors_list = deepcopy(original_bar_colors)
                            text_colors_list = deepcopy(original_text_colors)
                            for i in range(m):
                                if i not in selected_y_cord:
                                    bars_colors_list[i] = "#ceced9"
                                    text_colors_list[i] = "#ceced9"
                            fig.data[0]["marker"]["color"] = bars_colors_list
                            fig.data[0]["textfont"]["color"] = text_colors_list

            for i in range(1, k):
                fig.data[i].on_click(_update_trace)
            return HBox([fig], layout=Layout(overflow='scroll', width=f'{fig.layout.width}px'))
        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, aspect)

Calculate the result of the explanation. The fit method makes calculations in place and changes the attributes.

Parameters

aspect : Aspect object
Explainer wrapper created using the Aspect class.

Returns

None
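The lower-level equivalent is to construct the explanation directly with the signature shown above and fit it with an Aspect object. A sketch (smaller N and B are chosen here only to keep the permutation cost down):

from dalex.aspect import ModelTriplot

mt = ModelTriplot(N=500, B=5, random_state=0)
mt.fit(aspect)   # aspect: a dalex Aspect object, as in the earlier examples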
 
def plot(self, digits=3, rounding_function=<function around>, show_change=True, bar_width=25, width=1500, vcolors=None, title='Model Triplot', widget=False, show=True)

Plot the Model Triplot explanation (triplot visualization).

Parameters

digits : int, optional
Number of decimal places (np.around) to round contributions. See rounding_function parameter (default is 3).
rounding_function : function, optional
A function that will be used for rounding numbers (default is np.around).
show_change : bool, optional
If True middle panel shows dropout loss change, otherwise dropout loss (default is True).
bar_width : float, optional
Width of bars in px (default is 25).
width : float, optional
Width of triplot in px (default is 1500).
vcolors : str, optional
Color of bars (default is "#46bac2").
title : str, optional
Title of the plot (default is "Model Triplot").
widget : bool, optional
If True triplot interactive widget version is generated (default is False).
show : bool, optional
True shows the plot; False returns the plotly Figure object (default is True). NOTE: Ignored if widget is True.

Returns

None or plotly.graph_objects.Figure or ipywidgets.HBox with plotly.graph_objs._figurewidget.FigureWidget
Return figure that can be edited or saved. See show parameter.
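A sketch of both output modes, continuing the mt object from above (the static export needs the kaleido package; the interactive variant needs ipywidgets and a notebook frontend):

fig = mt.plot(show_change=True, show=False)   # plotly Figure
fig.write_image("model_triplot.png")

box = mt.plot(widget=True)   # ipywidgets.HBox wrapping a FigureWidget
box                          # display in a notebook cell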
Expand source code Browse git
def plot(
    self,
    digits=3,
    rounding_function=np.around,
    show_change = True,
    bar_width=25,
    width=1500,
    vcolors=None,
    title="Model Triplot",
    widget=False,
    show=True
):
    """Plot the Model Triplot explanation (triplot visualization).

    Parameters
    ----------
    digits : int, optional
        Number of decimal places (`np.around`) to round contributions.
        See `rounding_function` parameter (default is `3`).
    rounding_function : function, optional
        A function that will be used for rounding numbers (default is `np.around`).
    show_change : bool, optional
        If `True` middle panel shows dropout loss change, otherwise dropout loss
        (default is `True`). 
    bar_width : float, optional
        Width of bars in px (default is `16`).
    width : float, optional
        Width of triplot in px (default is `1500`).
    vcolors : str, optional
        Color of bars (default is `"#46bac2"`).
    title : str, optional
        Title of the plot (default is `"Model Triplot"`).
    widget : bool, optional
        If `True` triplot interactive widget version is generated
        (default is `False`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object 
        (default is `True`).
        NOTE: Ignored if `widget` is `True`.

    Returns
    -------
    None or plotly.graph_objects.Figure or ipywidgets.HBox with plotly.graph_objs._figurewidget.FigureWidget
        Return figure that can be edited or saved. See `show` parameter.
    """
    _global_checks.global_check_import('kaleido', 'Model Triplot')
    ## right plot
    hierarchical_clustering_dendrogram_plot_without_annotations = (
        self._hierarchical_clustering_dendrogram
    )
    variables_order = list(
        hierarchical_clustering_dendrogram_plot_without_annotations.layout.yaxis.ticktext
    )
    m = len(variables_order)

    ## middle plot
    (
        hierarchical_importance_dendrogram_plot_without_annotations,
        updated_dendro_traces,
    ) = plot.plot_model_hierarchical_importance(
        hierarchical_clustering_dendrogram_plot_without_annotations,
        self.result,
        rounding_function,
        digits,
        show_change
    )

    hierarchical_clustering_dendrogram_plot = plot.add_text_to_dendrogram(
        hierarchical_clustering_dendrogram_plot_without_annotations,
        updated_dendro_traces,
        rounding_function, 
        digits,
        type="clustering",
    )

    hierarchical_importance_dendrogram_plot = plot.add_text_to_dendrogram(
        hierarchical_importance_dendrogram_plot_without_annotations,
        updated_dendro_traces,
        rounding_function, 
        digits,
        type="importance",
    )

    ## left plot
    fig = plot.plot_single_aspects_importance(
        self.single_variable_importance,
        variables_order,
        rounding_function,
        digits,
        vcolors
    )
    
    fig.layout["xaxis"]["range"] = (
        fig.layout["xaxis"]["range"][0],
        fig.layout["xaxis"]["range"][1] * 1.05,
    )
    y_vals = [-5 - i * 10 for i in range(m)]
    fig.data[0]["y"] = y_vals

    ## triplot
    min_x_imp, max_x_imp = np.Inf, -np.Inf
    for data in hierarchical_importance_dendrogram_plot["data"][::-1]:
        data["xaxis"] = "x2"
        data["hoverinfo"] = "text"
        data["line"] = {"color": "#46bac2", "width": 2}
        fig.add_trace(data)
        min_x_imp = np.min([min_x_imp, np.min(data["x"])])
        max_x_imp = np.max([max_x_imp, np.max(data["x"])])
    min_max_margin_imp = (max_x_imp - min_x_imp) * 0.15

    min_x_clust, max_x_clust = np.Inf, -np.Inf
    for data in hierarchical_clustering_dendrogram_plot["data"]:
        data["xaxis"] = "x3"
        data["hoverinfo"] = "text"
        data["line"] = {"color": "#46bac2", "width": 2}
        fig.add_trace(data)
        min_x_clust = np.min([min_x_clust, np.min(data["x"])])
        max_x_clust = np.max([max_x_clust, np.max(data["x"])])
    min_max_margin_clust = (max_x_clust - min_x_clust) * 0.15

    plot_height = 78 + 71 + m * bar_width + (m + 1) * bar_width / 4

    fig.update_layout(
        xaxis={
            "autorange": False,
            "domain": [0, 0.33],
            "mirror": False,
            "showgrid": False,
            "showline": False,
            "zeroline": False,
            "showticklabels": True,
            "ticks": "",
            "title_text": "Variable importance",
        },
        xaxis2={
            "domain": [0.33, 0.66],
            "mirror": False,
            "showgrid": False,
            "showline": False,
            "zeroline": False,
            "showticklabels": True,
            "tickvals": [0],
            "ticktext": [""],
            "ticks": "",
            "title_text": "Hierarchical aspect importance",
            "fixedrange": True,
            "autorange": False,
            "range": [
                min_x_imp - min_max_margin_imp,
                max_x_imp + min_max_margin_imp,
            ],
        },
        xaxis3={
            "domain": [0.66, 0.99],
            "mirror": False,
            "showgrid": False,
            "showline": False,
            "zeroline": False,
            "showticklabels": True,
            "tickvals": [0],
            "ticktext": [""],
            "ticks": "",
            "title_text": "Hierarchical clustering",
            "fixedrange": True,
            "autorange": False,
            "range": [
                min_x_clust - min_max_margin_clust,
                max_x_clust + min_max_margin_clust,
            ],
        },
        yaxis={
            "mirror": False,
            "ticks": "",
            "fixedrange": True,
            "gridwidth": 1,
            "type": "linear",
            "tickmode": "array",
            "tickvals": y_vals,
            "ticktext": variables_order,
        },
        title_text=title,
        title_x=0.5,
        font={"color": "#371ea3"},
        template="none",
        margin={"t": 78, "b": 71, "r": 30},
        width=width,
        height=plot_height,
        showlegend=False,
        hovermode="closest",

    )

    fig, middle_point = plot._add_points_on_dendrogram_traces(fig)

    ##################################################################

    if widget:
        _global_checks.global_check_import('ipywidgets', 'Model Triplot')
        from ipywidgets import HBox, Layout
        fig = go.FigureWidget(fig, layout={"autosize": True, "hoverdistance": 100})
        original_bar_colors = deepcopy([fig.data[0]["marker"]["color"]] * m)
        original_text_colors = deepcopy(list(fig.data[0]["textfont"]["color"]))
        k = len(fig.data)
        updated_dendro_traces_in_full_figure = list(
            np.array(updated_dendro_traces) + (k - 1) / 2 + 1
        ) + list((k - 1) / 2 - np.array(updated_dendro_traces))

        def _update_childs(x, y, fig, k, selected, selected_y_cord):
            for i in range(1, k):
                if middle_point[i] == (x, y):
                    fig.data[i]["line"]["color"] = "#46bac2"
                    fig.data[i]["line"]["width"] = 3
                    fig.data[k - i]["line"]["color"] = "#46bac2"
                    fig.data[k - i]["line"]["width"] = 3
                    selected.append(i)
                    selected.append(k - i)
                    if (fig.data[i]["y"][0] + 5) % 10 == 0:
                        selected_y_cord.append((fig.data[i]["y"][0] + 5) // -10)
                    if (fig.data[i]["y"][-1] - 5) % 10 == 0:
                        selected_y_cord.append((fig.data[i]["y"][-1] + 5) // -10)
                    _update_childs(
                        fig.data[i]["x"][0],
                        fig.data[i]["y"][0],
                        fig,
                        k,
                        selected,
                        selected_y_cord,
                    )
                    _update_childs(
                        fig.data[i]["x"][-1],
                        fig.data[i]["y"][-1],
                        fig,
                        k,
                        selected,
                        selected_y_cord,
                    )

        def _update_trace(trace, points, selector):
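            # click handler: clicking the root branch (the one reaching the
            # panel's maximum x) restores the original styling; clicking an inner
            # branch highlights it, its mirror in the other panel, and all of its
            # descendants, while greying out the remaining traces and bars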
            if len(points.point_inds) == 1:
                selected_ind = points.trace_index
                with fig.batch_update():
                    if max(fig.data[selected_ind]["x"]) in (max_x_clust, max_x_imp):
                        for i in range(1, k):
                            fig.data[i]["line"]["color"] = "#46bac2"
                            fig.data[i]["line"]["width"] = 2
                            fig.data[i]["textfont"]["color"] = "#371ea3"
                            fig.data[i]["textfont"]["size"] = 12
                        fig.data[0]["marker"]["color"] = original_bar_colors
                        fig.data[0]["textfont"]["color"] = original_text_colors
                    else:
                        selected = [selected_ind, k - selected_ind]
                        selected_y_cord = []
                        if (fig.data[selected_ind]["y"][0] - 5) % 10 == 0:
                            selected_y_cord.append(
                                (fig.data[selected_ind]["y"][0] + 5) // -10
                            )
                        if (fig.data[selected_ind]["y"][-1] - 5) % 10 == 0:
                            selected_y_cord.append(
                                (fig.data[selected_ind]["y"][-1] + 5) // -10
                            )
                        fig.data[selected_ind]["line"]["color"] = "#46bac2"
                        fig.data[selected_ind]["line"]["width"] = 3
                        fig.data[selected_ind]["textfont"]["color"] = "#371ea3"
                        fig.data[selected_ind]["textfont"]["size"] = 14
                        fig.data[k - selected_ind]["line"]["color"] = "#46bac2"
                        fig.data[k - selected_ind]["line"]["width"] = 3
                        fig.data[k - selected_ind]["textfont"]["color"] = "#371ea3"
                        fig.data[k - selected_ind]["textfont"]["size"] = 14
                        _update_childs(
                            fig.data[selected_ind]["x"][0],
                            fig.data[selected_ind]["y"][0],
                            fig,
                            k,
                            selected,
                            selected_y_cord,
                        )
                        _update_childs(
                            fig.data[selected_ind]["x"][-1],
                            fig.data[selected_ind]["y"][-1],
                            fig,
                            k,
                            selected,
                            selected_y_cord,
                        )
                        for i in range(1, k):
                            if i not in [selected_ind, k - selected_ind]:
                                fig.data[i]["textfont"]["color"] = "#ceced9"
                                fig.data[i]["textfont"]["size"] = 12
                                if i not in selected:
                                    fig.data[i]["line"]["color"] = "#ceced9"
                                    fig.data[i]["line"]["width"] = 1

                        bars_colors_list = deepcopy(original_bar_colors)
                        text_colors_list = deepcopy(original_text_colors)
                        for i in range(m):
                            if i not in selected_y_cord:
                                bars_colors_list[i] = "#ceced9"
                                text_colors_list[i] = "#ceced9"
                        fig.data[0]["marker"]["color"] = bars_colors_list
                        fig.data[0]["textfont"]["color"] = text_colors_list

        for i in range(1, k):
            fig.data[i].on_click(_update_trace)
        return HBox([fig], layout=Layout(overflow='scroll', width=f'{fig.layout.width}px'))
    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class PredictAspectImportance (variable_groups, type='default', N=2000, B=25, n_aspects=None, sample_method='default', f=2, depend_method='assoc', corr_method='spearman', agg_method='max', processes=1, random_state=None, **kwargs)

Calculate predict-level aspect importance

Parameters

variable_groups : dict of lists
Variables grouped in aspects to calculate their importance.
type : {'default', 'shap'}, optional
Type of aspect importance/attributions (default is 'default', which means the use of simplified LIME method).
N : int, optional
Number of observations that will be sampled from the data attribute before the calculation of aspect importance (default is 2000).
B : int, optional
Parameter specific for type == 'shap'. Number of random paths to calculate aspect attributions (default is 25). NOTE: Ignored if type is not 'shap'.
n_aspects : int, optional
Parameter specific for type == 'default'. Maximum number of non-zero importances, i.e. coefficients after lasso fitting (default is None, which means the linear regression is used). NOTE: Ignored if type is not 'default'.
sample_method : {'default', 'binom'}, optional
Parameter specific for type == 'default'. Sampling method for creating binary matrix used as mask for replacing aspects in sampled data (default is 'default', which means it randomly replaces one or two zeros per row; 'binom' replaces random number of zeros per row). NOTE: Ignored if type is not 'default'.
f : int, optional
Parameter specific for type == 'default' and sample_method == 'binom'. Parameter controlling average number of replaced zeros for binomial sampling (default is 2). NOTE: Ignored if type is not 'default' or sample_method is not 'binom'.
depend_method : {'assoc', 'pps'} or function, optional
The method of calculating the dependencies between variables (i.e. the dependency matrix). Default is 'assoc', which means the use of statistical association (correlation coefficient, Cramér's V and eta-squared); 'pps' stands for Power Predictive Score. NOTE: When a function is passed, it is called with the data and it must return a symmetric dependency matrix (pd.DataFrame with variable names as columns and rows).
corr_method : {'spearman', 'pearson', 'kendall'}, optional
The method of calculating correlation between numerical variables (default is 'spearman'). NOTE: Ignored if depend_method is not 'assoc'.
agg_method : {'max', 'min', 'avg'}, optional
The method of aggregating the PPS values for pairs of variables (default is 'max'). NOTE: Ignored if depend_method is not 'pps'.
processes : int, optional
Parameter specific for type == 'shap'. Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
prediction : float
Prediction for new_observation.
intercept : float
Average prediction for data.
variable_groups : dict of lists
Variables grouped in aspects to calculate their importance.
type : {'default', 'shap'}
Type of aspect importance/attributions to calculate.
N : int
Number of observations that will be sampled from the data attribute before the calculation of aspect importance.
B : int
Number of random paths to calculate aspect attributions.
n_aspects : int
Maximum number of non-zero importances.
sample_method : {'default', 'binom'}
Sampling method for creating binary matrix used as mask for replacing aspects in sampled data.
f : int
Average number of replaced zeros for binomial sampling.
depend_method : {'assoc', 'pps'}
The method of calculating the dependencies between variables.
corr_method : {'spearman', 'pearson', 'kendall'}
The method of calculating correlation between numerical variables.
agg_method : {'max', 'min', 'avg'}
The method of aggregating the PPS values for pairs of variables.
processes : int
Number of parallel processes to use in calculations. Iterated over B.
random_state : int
Set seed for random number generator.
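
A minimal usage sketch (illustrative only; the scikit-learn model, the diabetes data, and the variable grouping below are assumptions, not part of dalex):

import dalex as dx
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor

# any model wrapped in dx.Explainer works the same way
X, y = load_diabetes(return_X_y=True, as_frame=True)
model = RandomForestRegressor(random_state=0).fit(X, y)
explainer = dx.Explainer(model, X, y, label="rf")

# an arbitrary grouping of the variables into aspects
variable_groups = {
    "demographic": ["age", "sex"],
    "body": ["bmi", "bp"],
    "serum": ["s1", "s2", "s3", "s4", "s5", "s6"],
}

pai = dx.aspect.PredictAspectImportance(
    variable_groups, type="default", N=1000, random_state=42
)
pai.fit(explainer, X.iloc[[0]])
pai.result  # pd.DataFrame with importance, min_depend and vars_min_depend columns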

Notes

Expand source code Browse git
class PredictAspectImportance(Explanation):
    """Calculate predict-level aspect importance

    Parameters
    -----------
    variable_groups : dict of lists 
        Variables grouped in aspects to calculate their importance. 
    type : {'default', 'shap'}, optional
        Type of aspect importance/attributions (default is `'default'`, which means 
        the use of simplified LIME method).
    N : int, optional
        Number of observations that will be sampled from the `data` attribute
        before the calculation of aspect importance (default is `2000`).
    B : int, optional
        Parameter specific for `type == 'shap'`. Number of random paths to calculate aspect
        attributions (default is `25`).
        NOTE: Ignored if `type` is not `'shap'`.
    n_aspects : int, optional
        Parameter specific for `type == 'default'`. Maximum number of non-zero importances, i.e.
        coefficients after lasso fitting (default is `None`, which means the linear regression is used).
        NOTE: Ignored if `type` is not `'default'`.
    sample_method : {'default', 'binom'}, optional
        Parameter specific for `type == 'default'`. Sampling method for creating binary matrix 
        used as mask for replacing aspects in sampled data (default is `'default'`, which means 
        it randomly replaces one or two zeros per row; `'binom'` replaces random number of zeros 
        per row).
        NOTE: Ignored if `type` is not `'default'`.
    f : int, optional
        Parameter specific for `type == 'default'` and `sample_method == 'binom'`. Parameter 
        controlling average number of replaced zeros for binomial sampling (default is `2`). 
        NOTE: Ignored if `type` is not `'default'` or `sample_method` is not `'binom'`.
    depend_method: {'assoc', 'pps'} or function, optional
        The method of calculating the dependencies between variables (i.e. the dependency 
        matrix). Default is `'assoc'`, which means the use of statistical association 
        (correlation coefficient, Cramér's V and eta-squared); 
        `'pps'` stands for Power Predictive Score.
        NOTE: When a function is passed, it is called with the `data` and it 
        must return a symmetric dependency matrix (`pd.DataFrame` with variable names as 
        columns and rows).
    corr_method : {'spearman', 'pearson', 'kendall'}, optional
        The method of calculating correlation between numerical variables 
        (default is `'spearman'`).
        NOTE: Ignored if `depend_method` is not `'assoc'`.
    agg_method : {'max', 'min', 'avg'}, optional
        The method of aggregating the PPS values for pairs of variables 
        (default is `'max'`).
        NOTE: Ignored if `depend_method` is not `'pps'`. 
    processes : int, optional
        Parameter specific for `type == 'shap'`. Number of parallel processes to use in calculations.
        Iterated over `B` (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    prediction : float
        Prediction for `new_observation`.
    intercept : float
        Average prediction for `data`.
    variable_groups : dict of lists 
        Variables grouped in aspects to calculate their importance. 
    type : {'default', 'shap'}
        Type of aspect importance/attributions to calculate.
    N : int
        Number of observations that will be sampled from the `data` attribute
        before the calculation of aspect importance.
    B : int
        Number of random paths to calculate aspect attributions.
    n_aspects : int
        Maximum number of non-zero importances.
    sample_method : {'default', 'binom'}
        Sampling method for creating binary matrix used as mask for replacing aspects in sampled data.
    f : int
        Average number of replaced zeros for binomial sampling.
    depend_method : {'assoc', 'pps'}
        The method of calculating the dependencies between variables.
    corr_method : {'spearman', 'pearson', 'kendall'}
        The method of calculating correlation between numerical variables.
    agg_method : {'max', 'min', 'avg'}
        The method of aggregating the PPS values for pairs of variables.
    processes : int
        Number of parallel processes to use in calculations. Iterated over `B`.
    random_state : int
        Set seed for random number generator.
    
    Notes
    -----
    - https://arxiv.org/abs/2104.03403
    """
    def __init__(
        self,
        variable_groups,
        type="default",
        N=2000,
        B=25,
        n_aspects=None,
        sample_method="default",
        f=2,
        depend_method="assoc",
        corr_method="spearman",
        agg_method="max",
        processes=1,
        random_state=None,
        **kwargs
    ):
        
        types = ('default', 'shap')
        aliases = {'simplified_lime': 'default', 'lime': 'default', 'shapley_values': 'shap'}
        _type = checks.check_method_type(type, types, aliases)
        _processes = checks.check_processes(processes)
        _random_state = checks.check_random_state(random_state)
        _depend_method, _corr_method, _agg_method = checks.check_method_depend(depend_method, corr_method, agg_method)
        self._depend_matrix = None
        if "_depend_matrix" in kwargs:
            self._depend_matrix = kwargs.get("_depend_matrix")
        
        self.variable_groups = variable_groups
        self.type = _type
        self.N = N
        self.B = B
        self.n_aspects = n_aspects
        self.sample_method = sample_method
        self.f = f
        self.depend_method = _depend_method
        self.corr_method = _corr_method
        self.agg_method = _agg_method
        self.random_state = _random_state
        self.processes = _processes
        self.prediction = None
        self.intercept = None
        self.result = pd.DataFrame()

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self, explainer, new_observation):
        """Calculate the result of explanation
        Fit method makes calculations in place and changes the attributes.

        Parameters
        ----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
            An observation for which a prediction needs to be explained.
        
        Returns
        -----------
        None
        """
        _new_observation = checks.check_new_observation(new_observation, explainer)
        checks.check_columns_in_new_observation(_new_observation, explainer)
        _variable_groups = checks.check_variable_groups(self.variable_groups, explainer)
        
        self.prediction = explainer.predict(_new_observation)[0]
        self.intercept = explainer.y_hat.mean()
        
        if self.type == "default":
            self.result = utils.calculate_predict_aspect_importance(
                explainer,
                _new_observation,
                _variable_groups,
                self.N,
                self.n_aspects,
                self.sample_method,
                self.f,
                self.random_state,
            )
        else:
            self.result = utils.calculate_shap_predict_aspect_importance(
                explainer, 
                _new_observation,
                _variable_groups,
                self.N,
                self.B,
                self.processes,
                self.random_state
            )

        self.result.insert(4, "min_depend", None)
        self.result.insert(5, "vars_min_depend", None)

        # if there is _depend_matrix in kwargs (called from Aspect object) 
        if self._depend_matrix is not None:
            vars_min_depend, min_depend = get_min_depend_from_matrix(self._depend_matrix, 
                    self.result.variable_names
                )
        else:
            vars_min_depend, min_depend = calculate_min_depend(
                self.result.variable_names, 
                explainer.data,
                self.depend_method,
                self.corr_method,
                self.agg_method,
            )

        self.result["min_depend"] = min_depend
        self.result["vars_min_depend"] = vars_min_depend

    def plot(
        self,
        objects=None,
        baseline=None,
        max_aspects=10,
        show_variable_names=True,
        digits=3,
        rounding_function=np.around,
        bar_width=25,
        min_max=None,
        vcolors=None,
        title="Predict Aspect Importance",
        vertical_spacing=None,
        show=True,
    ):
        """Plot the Predict Aspect Importance explanation.

        Parameters
        ----------
        objects : PredictAspectImportance object or array_like of PredictAspectImportance objects
            Additional objects to plot in subplots (default is `None`).
        baseline: float, optional
            Starting x point for bars 
            (default is 0 if `type` is `'default'` and average prediction if `type` is `'shap'`).
        max_aspects : int, optional
            Maximum number of aspects that will be presented for each subplot
            (default is `10`).
        show_variable_names : bool, optional
            `True` shows names of variables grouped in aspects; `False` shows names of aspects
            (default is `True`).
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        bar_width : float, optional
            Width of bars in px (default is `25`).
        min_max : 2-tuple of float, optional
            Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`).
        vcolors : 2-tuple of str, optional
            Color of bars (default is `["#8bdcbe", "#f05a71"]`).
        title : str, optional
            Title of the plot (default is `"Predict Aspect Importance"`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots (default is `0.2/number of rows`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can 
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """

        _result_list = [self.result.copy()]
        _intercept_list = [self.intercept]
        # are there any other objects to plot?
        if objects is None:
            n = 1
        elif isinstance(objects, self.__class__):
            n = 2
            _result_list += [objects.result.copy()]
            _intercept_list += [objects.intercept]
        elif isinstance(objects, (list, tuple)):
            n = len(objects) + 1
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _result_list += [ob.result.copy()]
                _intercept_list += [ob.intercept]
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        model_names = [
            result.iloc[0, result.columns.get_loc("label")] for result in _result_list
        ]

        if vertical_spacing is None:
            vertical_spacing = 0.2 / n

        # generate plot
        fig = make_subplots(
            rows=n,
            cols=1,
            shared_xaxes=True,
            vertical_spacing=vertical_spacing,
            x_title="aspect importance",
            subplot_titles=model_names,
        )

        plot_height = 78 + 71

        if vcolors is None:
            vcolors = _theme.get_aspect_importance_colors()

        if min_max is None:
            temp_min_max = [np.Inf, -np.Inf]
        else:
            temp_min_max = min_max

        for i, _result in enumerate(_result_list):
            if _result.shape[0] <= max_aspects:
                m = _result.shape[0]
            else:
                m = max_aspects + 1

            if baseline is None:
                if self.type == 'shap':
                    baseline = _intercept_list[i]
                else: 
                    baseline = 0 
            
            _result = _result.iloc[:max_aspects, :]
            _result.loc[:, "importance"] = rounding_function(
                _result.loc[:, "importance"], digits
            )

            _result["color"] = [0 if imp > 0 else 1 for imp in _result["importance"]]
            _result["tooltip_text"] = _result.apply(
                lambda row: plot.tooltip_text(row, rounding_function, digits, self.type),
                axis=1,
            )
            _result["label_text"] = _global_utils.convert_float_to_str(
                _result.importance, "+"
            )

            fig.add_shape(
                type="line",
                x0=baseline,
                x1=baseline,
                y0=-1,
                y1=m,
                yref="paper",
                xref="x",
                line={"color": "#371ea3", "width": 1.5, "dash": "dot"},
                row=i + 1,
                col=1,
            )

            fig.add_bar(
                orientation="h",
                y=[
                    ", ".join(variables_list)
                    for variables_list in _result["variable_names"]
                ]
                if show_variable_names
                else _result["aspect_name"].tolist(),
                x=_result["importance"].tolist(),
                textposition="outside",
                text=_result["label_text"].tolist(),
                marker_color=[vcolors[int(c)] for c in _result["color"].tolist()],
                base=baseline,
                hovertext=_result["tooltip_text"].tolist(),
                hoverinfo="text",
                hoverlabel={"bgcolor": "rgba(0,0,0,0.8)"},
                showlegend=False,
                row=i + 1,
                col=1,
            )

            fig.update_yaxes(
                {
                    "type": "category",
                    "autorange": "reversed",
                    "gridwidth": 2,
                    "automargin": True,
                    "ticks": "outside",
                    "tickcolor": "white",
                    "ticklen": 10,
                    "fixedrange": True,
                },
                row=i + 1,
                col=1,
            )

            fig.update_xaxes(
                {
                    "type": "linear",
                    "gridwidth": 2,
                    "zeroline": False,
                    "automargin": True,
                    "ticks": "outside",
                    "tickcolor": "white",
                    "ticklen": 3,
                    "fixedrange": True,
                },
                row=i + 1,
                col=1,
            )

            plot_height += m * bar_width + (m + 1) * bar_width / 4

            if min_max is None:
                cum = _result.importance.values + baseline
                min_max_margin =  cum.ptp() * 0.15 
                temp_min_max[0] = np.min(
                    [
                        temp_min_max[0],
                        cum.min() - min_max_margin,
                    ]
                )
                temp_min_max[1] = np.max(
                    [
                        temp_min_max[1],
                        cum.max() + min_max_margin,
                    ]
                )

        plot_height += (n - 1) * 70

        fig.update_xaxes({"range": temp_min_max})
        fig.update_layout(
            title_text=title,
            title_x=0.15,
            font={"color": "#371ea3"},
            template="none",
            height=plot_height,
            margin={"t": 78, "b": 71, "r": 30},
        )

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, explainer, new_observation)

Calculate the result of explanation. The fit method makes calculations in place and changes the attributes.

Parameters

explainer : Explainer dalex.aspect.object
Model wrapper created using the Explainer class.
new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
An observation for which a prediction needs to be explained.

Returns

None
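
For example, continuing the sketch from the class description above (the observation choice is arbitrary):

pai = dx.aspect.PredictAspectImportance(variable_groups, type="shap", B=10, random_state=42)
pai.fit(explainer, X.iloc[[0]])  # computes in place
pai.prediction                   # prediction for the observation
pai.intercept                    # average prediction for explainer.data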
 
Expand source code Browse git
def fit(self, explainer, new_observation):
    """Calculate the result of explanation
    Fit method makes calculations in place and changes the attributes.

    Parameters
    ----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.
    new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
        An observation for which a prediction needs to be explained.
    
    Returns
    -----------
    None
    """
    _new_observation = checks.check_new_observation(new_observation, explainer)
    checks.check_columns_in_new_observation(_new_observation, explainer)
    _variable_groups = checks.check_variable_groups(self.variable_groups, explainer)
    
    self.prediction = explainer.predict(_new_observation)[0]
    self.intercept = explainer.y_hat.mean()
    
    if self.type == "default":
        self.result = utils.calculate_predict_aspect_importance(
            explainer,
            _new_observation,
            _variable_groups,
            self.N,
            self.n_aspects,
            self.sample_method,
            self.f,
            self.random_state,
        )
    else:
        self.result = utils.calculate_shap_predict_aspect_importance(
            explainer, 
            _new_observation,
            _variable_groups,
            self.N,
            self.B,
            self.processes,
            self.random_state
        )

    self.result.insert(4, "min_depend", None)
    self.result.insert(5, "vars_min_depend", None)

    # if there is _depend_matrix in kwargs (called from Aspect object) 
    if self._depend_matrix is not None:
        vars_min_depend, min_depend = get_min_depend_from_matrix(self._depend_matrix, 
                self.result.variable_names
            )
    else:
        vars_min_depend, min_depend = calculate_min_depend(
            self.result.variable_names, 
            explainer.data,
            self.depend_method,
            self.corr_method,
            self.agg_method,
        )

    self.result["min_depend"] = min_depend
    self.result["vars_min_depend"] = vars_min_depend
def plot(self, objects=None, baseline=None, max_aspects=10, show_variable_names=True, digits=3, rounding_function=np.around, bar_width=25, min_max=None, vcolors=None, title='Predict Aspect Importance', vertical_spacing=None, show=True)

Plot the Predict Aspect Importance explanation.

Parameters

objects : PredictAspectImportance dalex.aspect.object or array_like of PredictAspectImportance objects
Additional objects to plot in subplots (default is None).
baseline : float, optional
Starting x point for bars (default is 0 if type is 'default' and average prediction if type is 'shap').
max_aspects : int, optional
Maximum number of aspects that will be presented for each subplot (default is 10).
show_variable_names : bool, optional
True shows names of variables grouped in aspects; False shows names of aspects (default is True).
digits : int, optional
Number of decimal places (np.around) to round contributions. See rounding_function parameter (default is 3).
rounding_function : function, optional
A function that will be used for rounding numbers (default is np.around).
bar_width : float, optional
Width of bars in px (default is 25).
min_max : 2-tuple of float, optional
Range of OX axis (default is [min-0.15*(max-min), max+0.15*(max-min)]).
vcolors : 2-tuple of str, optional
Color of bars (default is ["#8bdcbe", "#f05a71"]).
title : str, optional
Title of the plot (default is "Predict Aspect Importance").
vertical_spacing : float <0, 1>, optional
Ratio of vertical space between the plots (default is 0.2/number of rows).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
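
For example, comparing two observations side by side (continuing the sketch above; writing the image assumes the kaleido package is installed):

pai2 = dx.aspect.PredictAspectImportance(variable_groups, random_state=42)
pai2.fit(explainer, X.iloc[[1]])
fig = pai.plot(objects=pai2, max_aspects=5, show=False)  # plotly Figure with two subplots
fig.write_image("predict_aspect_importance.png")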
Expand source code Browse git
def plot(
    self,
    objects=None,
    baseline=None,
    max_aspects=10,
    show_variable_names=True,
    digits=3,
    rounding_function=np.around,
    bar_width=25,
    min_max=None,
    vcolors=None,
    title="Predict Aspect Importance",
    vertical_spacing=None,
    show=True,
):
    """Plot the Predict Aspect Importance explanation.

    Parameters
    ----------
    objects : PredictAspectImportance object or array_like of PredictAspectImportance objects
        Additional objects to plot in subplots (default is `None`).
    baseline: float, optional
        Starting x point for bars 
        (default is 0 if `type` is `'default'` and average prediction if `type` is `'shap'`).
    max_aspects : int, optional
        Maximum number of aspects that will be presented for each subplot
        (default is `10`).
    show_variable_names : bool, optional
        `True` shows names of variables grouped in aspects; `False` shows names of aspects
        (default is `True`).
    digits : int, optional
        Number of decimal places (`np.around`) to round contributions.
        See `rounding_function` parameter (default is `3`).
    rounding_function : function, optional
        A function that will be used for rounding numbers (default is `np.around`).
    bar_width : float, optional
        Width of bars in px (default is `25`).
    min_max : 2-tuple of float, optional
        Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`).
    vcolors : 2-tuple of str, optional
        Color of bars (default is `["#8bdcbe", "#f05a71"]`).
    title : str, optional
        Title of the plot (default is `"Predict Aspect Importance"`).
    vertical_spacing : float <0, 1>, optional
        Ratio of vertical space between the plots (default is `0.2/number of rows`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can 
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """

    _result_list = [self.result.copy()]
    _intercept_list = [self.intercept]
    # are there any other objects to plot?
    if objects is None:
        n = 1
    elif isinstance(objects, self.__class__):
        n = 2
        _result_list += [objects.result.copy()]
        _intercept_list += [objects.intercept]
    elif isinstance(objects, (list, tuple)):
        n = len(objects) + 1
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _result_list += [ob.result.copy()]
            _intercept_list += [ob.intercept]
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    model_names = [
        result.iloc[0, result.columns.get_loc("label")] for result in _result_list
    ]

    if vertical_spacing is None:
        vertical_spacing = 0.2 / n

    # generate plot
    fig = make_subplots(
        rows=n,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=vertical_spacing,
        x_title="aspect importance",
        subplot_titles=model_names,
    )

    plot_height = 78 + 71

    if vcolors is None:
        vcolors = _theme.get_aspect_importance_colors()

    if min_max is None:
        temp_min_max = [np.Inf, -np.Inf]
    else:
        temp_min_max = min_max

    for i, _result in enumerate(_result_list):
        if _result.shape[0] <= max_aspects:
            m = _result.shape[0]
        else:
            m = max_aspects + 1

        if baseline is None:
            if self.type == 'shap':
                baseline = _intercept_list[i]
            else: 
                baseline = 0 
        
        _result = _result.iloc[:max_aspects, :]
        _result.loc[:, "importance"] = rounding_function(
            _result.loc[:, "importance"], digits
        )

        _result["color"] = [0 if imp > 0 else 1 for imp in _result["importance"]]
        _result["tooltip_text"] = _result.apply(
            lambda row: plot.tooltip_text(row, rounding_function, digits, self.type),
            axis=1,
        )
        _result["label_text"] = _global_utils.convert_float_to_str(
            _result.importance, "+"
        )

        fig.add_shape(
            type="line",
            x0=baseline,
            x1=baseline,
            y0=-1,
            y1=m,
            yref="paper",
            xref="x",
            line={"color": "#371ea3", "width": 1.5, "dash": "dot"},
            row=i + 1,
            col=1,
        )

        fig.add_bar(
            orientation="h",
            y=[
                ", ".join(variables_list)
                for variables_list in _result["variable_names"]
            ]
            if show_variable_names
            else _result["aspect_name"].tolist(),
            x=_result["importance"].tolist(),
            textposition="outside",
            text=_result["label_text"].tolist(),
            marker_color=[vcolors[int(c)] for c in _result["color"].tolist()],
            base=baseline,
            hovertext=_result["tooltip_text"].tolist(),
            hoverinfo="text",
            hoverlabel={"bgcolor": "rgba(0,0,0,0.8)"},
            showlegend=False,
            row=i + 1,
            col=1,
        )

        fig.update_yaxes(
            {
                "type": "category",
                "autorange": "reversed",
                "gridwidth": 2,
                "automargin": True,
                "ticks": "outside",
                "tickcolor": "white",
                "ticklen": 10,
                "fixedrange": True,
            },
            row=i + 1,
            col=1,
        )

        fig.update_xaxes(
            {
                "type": "linear",
                "gridwidth": 2,
                "zeroline": False,
                "automargin": True,
                "ticks": "outside",
                "tickcolor": "white",
                "ticklen": 3,
                "fixedrange": True,
            },
            row=i + 1,
            col=1,
        )

        plot_height += m * bar_width + (m + 1) * bar_width / 4

        if min_max is None:
            cum = _result.importance.values + baseline
            min_max_margin =  cum.ptp() * 0.15 
            temp_min_max[0] = np.min(
                [
                    temp_min_max[0],
                    cum.min() - min_max_margin,
                ]
            )
            temp_min_max[1] = np.max(
                [
                    temp_min_max[1],
                    cum.max() + min_max_margin,
                ]
            )

    plot_height += (n - 1) * 70

    fig.update_xaxes({"range": temp_min_max})
    fig.update_layout(
        title_text=title,
        title_x=0.15,
        font={"color": "#371ea3"},
        template="none",
        height=plot_height,
        margin={"t": 78, "b": 71, "r": 30},
    )

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class PredictTriplot (type='default', N=2000, B=25, sample_method='default', f=2, processes=1, random_state=None)

Calculate predict-level hierarchical aspect importance

Parameters

type : {'default', 'shap'}, optional
Type of aspect importance/attributions (default is 'default', which means the use of simplified LIME method).
N : int, optional
Number of observations that will be sampled from the explainer.data attribute before the calculation of aspect importance (default is 2000).
B : int, optional
Parameter specific for type == 'shap'. Number of random paths to calculate aspect attributions (default is 25). NOTE: Ignored if type is not 'shap'.
sample_method : {'default', 'binom'}, optional
Parameter specific for type == 'default'. Sampling method for creating binary matrix used as mask for replacing aspects in data (default is 'default', which means it randomly replaces one or two zeros per row; 'binom' replaces random number of zeros per row). NOTE: Ignored if type is not 'default'.
f : int, optional
Parameter specific for type == 'default' and sample_method == 'binom'. Parameter controlling average number of replaced zeros for binomial sampling (default is 2). NOTE: Ignored if type is not 'default' or sample_method is not 'binom'.
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
single_variable_importance : pd.DataFrame
Additional result attribute of an explanation (it contains information about the importance of individual variables).
prediction : float
Prediction for new_observation.
intercept : float
Average prediction for data.
type : {'default', 'shap'}
Type of aspect importance/attributions to calculate.
N : int
Number of observations that will be sampled from the data attribute before the calculation of aspect importance.
B : int
Number of random paths to calculate aspect attributions.
sample_method : {'default', 'binom'}
Sampling method for creating binary matrix used as mask for replacing aspects in sampled data.
f : int
Average number of replaced zeros for binomial sampling.
processes : int
Number of parallel processes to use in calculations. Iterated over B.
random_state : int or None
Set seed for random number generator.
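
A minimal usage sketch (reusing the illustrative explainer from the PredictAspectImportance example above; note that fit takes an Aspect object, not an Explainer):

aspect = dx.aspect.Aspect(explainer)
pt = dx.aspect.PredictTriplot(type="default", N=1000, random_state=42)
pt.fit(aspect, X.iloc[[0]])
pt.result                      # hierarchical aspect importance (middle panel)
pt.single_variable_importance  # single-variable importance (left panel)
fig = pt.plot(show=False)      # plotly Figure with the triplot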

Notes

Expand source code Browse git
class PredictTriplot(Explanation):
    """Calculate predict-level hierarchical aspect importance

    Parameters
    ----------
    type : {'default', 'shap'}, optional
        Type of aspect importance/attributions (default is `'default'`, which means 
        the use of simplified LIME method).
    N : int, optional
        Number of observations that will be sampled from the `explainer.data` attribute
        before the calculation of aspect importance (default is `2000`).
    B : int, optional
        Parameter specific for `type == 'shap'`. Number of random paths to calculate aspect
        attributions (default is `25`).
        NOTE: Ignored if `type` is not `'shap'`.
    sample_method : {'default', 'binom'}, optional
        Parameter specific for `type == 'default'`. Sampling method for creating binary matrix 
        used as mask for replacing aspects in data (default is `'default'`, which means 
        it randomly replaces one or two zeros per row; `'binom'` replaces random number of zeros 
        per row).
        NOTE: Ignored if `type` is not `'default'`.
    f : int, optional
        Parameter specific for `type == 'default'` and `sample_method == 'binom'`. Parameter 
        controlling average number of replaced zeros for binomial sampling (default is `2`). 
        NOTE: Ignored if `type` is not `'default'` or `sample_method` is not `'binom'`.
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    single_variable_importance : pd.DataFrame
        Additional result attribute of an explanation (it contains information 
        about the importance of individual variables).
    prediction : float
        Prediction for `new_observation`.
    intercept : float
        Average prediction for `data`.
    type : {'default', 'shap'}
        Type of aspect importance/attributions to calculate.
    N : int
        Number of observations that will be sampled from the `data` attribute
        before the calculation of aspect importance.
    B : int
        Number of random paths to calculate aspect attributions.
    sample_method : {'default', 'binom'}
        Sampling method for creating binary matrix used as mask for replacing aspects in sampled data.
    f : int
        Average number of replaced zeros for binomial sampling.
    processes : int
        Number of parallel processes to use in calculations. Iterated over `B`.
    random_state : int or None
        Set seed for random number generator.

    Notes
    -----
    - https://arxiv.org/abs/2104.03403
    """
    def __init__(
        self,
        type="default",
        N=2000,
        B=25,
        sample_method="default",
        f=2,
        processes=1,
        random_state=None,
    ):
        types = ("default", "shap")
        aliases = {"simplified_lime": "default", "lime": "default"}
        _type = checks.check_method_type(type, types, aliases)
        self.type = _type
        _processes = checks.check_processes(processes)
        self.processes = _processes
        _random_state = checks.check_random_state(random_state)
        self.random_state = _random_state
        self.N = N
        self.B = checks.check_B(B)
        self.sample_method = sample_method
        self.f = f
        self.result = pd.DataFrame()
        self.single_variable_importance = None
        self._hierarchical_clustering_dendrogram = None

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self, aspect, new_observation):
        """Calculate the result of explanation
        Fit method makes calculations in place and changes the attributes.

        Parameters
        ----------
        aspect : Aspect object
            Explainer wrapper created using the Aspect class.
        new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
            An observation for which a prediction needs to be explained.
        
        Returns
        -----------
        None
        """

        _new_observation = checks.check_new_observation(new_observation, aspect.explainer)
        checks.check_columns_in_new_observation(_new_observation, aspect.explainer)

        self._hierarchical_clustering_dendrogram = aspect._hierarchical_clustering_dendrogram

        self.prediction = aspect.explainer.predict(_new_observation)[0]
        self.intercept = aspect.explainer.y_hat.mean()

        ## middle plot data
        self.result = (
            utils.calculate_predict_hierarchical_importance(
                aspect,
                _new_observation,
                self.type,
                self.N,
                self.B,
                self.sample_method,
                self.f,
                self.processes,
                self.random_state,
            )
        )

        self.result.insert(3, "min_depend", None)
        self.result.insert(4, "vars_min_depend", None)
        for index, row in self.result.iterrows():
            _matching_row = aspect._dendrogram_aspects_ordered.loc[
                pd.Series(map(set, aspect._dendrogram_aspects_ordered.variable_names))
                == set(row.variable_names)
            ]
            min_dep = _matching_row.min_depend.values[0]
            vars_min_depend = _matching_row.vars_min_depend.values[0]
            self.result.at[index, "min_depend"] = min_dep
            self.result.at[
                index, "vars_min_depend"
            ] = vars_min_depend

        ## left plot data
        self.single_variable_importance = utils.calculate_single_variable_importance(
            aspect,
            _new_observation,
            self.type,
            self.N,
            self.B, 
            self.sample_method,
            self.f,
            self.processes,
            self.random_state
        )
        
    def plot(
        self,
        absolute_value=False,
        digits=3,
        rounding_function=np.around,
        bar_width=25,
        width=1500,
        abbrev_labels=0,
        vcolors=None,
        title="Predict Triplot",
        widget=False,
        show=True
    ):
        """Plot the Predict Triplot explanation (triplot visualization).

        Parameters
        ----------
        absolute_value : bool, optional
            If `True` aspect importance values are drawn as absolute values 
            (default is `False`).
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        bar_width : float, optional
            Width of bars in px (default is `25`).
        width : float, optional
            Width of triplot in px (default is `1500`).
        abbrev_labels : int, optional
            If greater than 0, labels for axis Y will be abbreviated according to this parameter
            (default is `0`, which means no abbreviation).
        vcolors : 2-tuple of str, optional
            Color of bars (default is `["#8bdcbe", "#f05a71"]`).
        title : str, optional
            Title of the plot (default is `"Predict Triplot"`).
        widget : bool, optional
            If `True` triplot interactive widget version is generated
            (default is `False`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object 
            (default is `True`).
            NOTE: Ignored if `widget` is `True`.

        Returns
        -------
        None or plotly.graph_objects.Figure or ipywidgets.HBox with plotly.graph_objs._figurewidget.FigureWidget
            Return figure that can be edited or saved. See `show` parameter.
        """    
        _global_checks.global_check_import('kaleido', 'Predict Triplot')
        ## right plot
        hierarchical_clustering_dendrogram_plot_without_annotations = go.Figure(
            self._hierarchical_clustering_dendrogram
        )
        variables_order = list(
            hierarchical_clustering_dendrogram_plot_without_annotations.layout.yaxis.ticktext
        )

        ## middle plot
        (
            hierarchical_importance_dendrogram_plot_without_annotations,
            updated_dendro_traces,
        ) = plot.plot_predict_hierarchical_importance(
            hierarchical_clustering_dendrogram_plot_without_annotations,
            self.result,
            rounding_function,
            digits,
            absolute_value,
            self.type,
        )

        hierarchical_clustering_dendrogram_plot = plot.add_text_to_dendrogram(
            hierarchical_clustering_dendrogram_plot_without_annotations,
            updated_dendro_traces,
            rounding_function,
            digits,
            type="clustering",
        )

        hierarchical_importance_dendrogram_plot = plot.add_text_to_dendrogram(
            hierarchical_importance_dendrogram_plot_without_annotations,
            updated_dendro_traces,
            rounding_function,
            digits,
            type="importance",
        )

        ## left plot
        fig = plot.plot_single_aspects_importance(
            self.single_variable_importance,
            variables_order,
            rounding_function,
            digits,
            vcolors,
        )
        fig.layout["xaxis"]["range"] = (
            fig.layout["xaxis"]["range"][0],
            fig.layout["xaxis"]["range"][1] * 1.05,
        )
        m = len(variables_order)
        y_vals = [-5 - i * 10 for i in range(m)]
        fig.data[0]["y"] = y_vals

        ## triplot

        fig.add_shape(
            type="line",
            x0=0,
            x1=0,
            y0=-0.01,
            y1=1.01,
            yref="paper",
            xref="x2",
            line={"color": "#371ea3", "width": 1.5, "dash": "dot"},
        )

        min_x_imp, max_x_imp = np.Inf, -np.Inf
        for data in hierarchical_importance_dendrogram_plot["data"][::-1]:
            data["xaxis"] = "x2"
            data["hoverinfo"] = "text"
            data["line"] = {"color": "#46bac2", "width": 2}
            fig.add_trace(data)
            min_x_imp = np.min([min_x_imp, np.min(data["x"])])
            max_x_imp = np.max([max_x_imp, np.max(data["x"])])
        min_max_margin_imp = (max_x_imp - min_x_imp) * 0.15

        min_x_clust, max_x_clust = np.Inf, -np.Inf
        for data in hierarchical_clustering_dendrogram_plot["data"]:
            data["xaxis"] = "x3"
            data["hoverinfo"] = "text"
            data["line"] = {"color": "#46bac2", "width": 2}
            fig.add_trace(data)
            min_x_clust = np.min([min_x_clust, np.min(data["x"])])
            max_x_clust = np.max([max_x_clust, np.max(data["x"])])
        min_max_margin_clust = (max_x_clust - min_x_clust) * 0.15

        plot_height = 78 + 71 + m * bar_width + (m + 1) * bar_width / 4
        ticktext = plot.get_ticktext_for_plot(
            self.single_variable_importance, variables_order, abbrev_labels
        )

        fig.update_layout(
            xaxis={
                "autorange": False,
                "domain": [0, 0.33],
                "mirror": False,
                "showgrid": False,
                "showline": False,
                "zeroline": False,
                "ticks": "",
                "title_text": "Local variable importance",
            },
            xaxis2={
                "domain": [0.33, 0.66],
                "mirror": False,
                "showgrid": False,
                "showline": False,
                "zeroline": False,
                "showticklabels": True,
                "tickvals": [0],
                "ticktext": [""],
                "ticks": "",
                "title_text": "Hierarchical aspect importance",
                "autorange": False,
                "fixedrange": True,
                "range": [
                    min_x_imp - min_max_margin_imp,
                    max_x_imp + min_max_margin_imp,
                ],
            },
            xaxis3={
                "domain": [0.66, 0.99],
                "mirror": False,
                "showgrid": False,
                "showline": False,
                "zeroline": False,
                "showticklabels": True,
                "tickvals": [0],
                "ticktext": [""],
                "ticks": "",
                "title_text": "Hierarchical clustering",
                "autorange": False,
                "fixedrange": True,
                "range": [
                    min_x_clust - min_max_margin_clust,
                    max_x_clust + min_max_margin_clust,
                ],
            },
            yaxis={
                "mirror": False,
                "ticks": "",
                "fixedrange": True,
                "gridwidth": 1,
                "type": "linear",
                "tickmode": "array",
                "tickvals": y_vals,
                "ticktext": ticktext,
            },
            title_text=title,
            title_x=0.5,
            font={"color": "#371ea3"},
            template="none",
            margin={"t": 78, "b": 71, "r": 30},
            width=width,
            height=plot_height,
            showlegend=False,
            hovermode="closest"
        )

        fig, middle_point = plot._add_points_on_dendrogram_traces(fig)

        ##################################################################

        if widget:
            _global_checks.global_check_import('ipywidgets', 'Predict Triplot')
            from ipywidgets import HBox, Layout
            fig = go.FigureWidget(fig, layout={"autosize": True, "hoverdistance": 100})
            original_bar_colors = deepcopy(list(fig.data[0]["marker"]["color"]))
            original_text_colors = deepcopy(list(fig.data[0]["textfont"]["color"]))
            k = len(fig.data)
            updated_dendro_traces_in_full_figure = list(
                np.array(updated_dendro_traces) + (k - 1) / 2 + 1
            ) + list((k - 1) / 2 - np.array(updated_dendro_traces))

            def _update_childs(x, y, fig, k, selected, selected_y_cord):
                for i in range(1, k):
                    if middle_point[i] == (x, y):
                        fig.data[i]["line"]["color"] = "#46bac2"
                        fig.data[i]["line"]["width"] = 3
                        fig.data[k - i]["line"]["color"] = "#46bac2"
                        fig.data[k - i]["line"]["width"] = 3
                        selected.append(i)
                        selected.append(k - i)
                        if (fig.data[i]["y"][0] + 5) % 10 == 0:
                            selected_y_cord.append((fig.data[i]["y"][0] + 5) // -10)
                        if (fig.data[i]["y"][-1] - 5) % 10 == 0:
                            selected_y_cord.append((fig.data[i]["y"][-1] + 5) // -10)
                        _update_childs(
                            fig.data[i]["x"][0],
                            fig.data[i]["y"][0],
                            fig,
                            k,
                            selected,
                            selected_y_cord,
                        )
                        _update_childs(
                            fig.data[i]["x"][-1],
                            fig.data[i]["y"][-1],
                            fig,
                            k,
                            selected,
                            selected_y_cord,
                        )

            def _update_trace(trace, points, selector):
                if len(points.point_inds) == 1:
                    selected_ind = points.trace_index
                    with fig.batch_update():
                        if selected_ind not in updated_dendro_traces_in_full_figure:
                            for i in range(1, k):
                                fig.data[i]["line"]["color"] = "#46bac2"
                                fig.data[i]["line"]["width"] = 2
                                fig.data[i]["textfont"]["color"] = "#371ea3"
                                fig.data[i]["textfont"]["size"] = 12
                            fig.data[0]["marker"]["color"] = original_bar_colors
                            fig.data[0]["textfont"]["color"] = original_text_colors
                        else:
                            selected = [selected_ind, k - selected_ind]
                            selected_y_cord = []
                            if (fig.data[selected_ind]["y"][0] - 5) % 10 == 0:
                                selected_y_cord.append(
                                    (fig.data[selected_ind]["y"][0] + 5) // -10
                                )
                            if (fig.data[selected_ind]["y"][-1] - 5) % 10 == 0:
                                selected_y_cord.append(
                                    (fig.data[selected_ind]["y"][-1] + 5) // -10
                                )
                            fig.data[selected_ind]["line"]["color"] = "#46bac2"
                            fig.data[selected_ind]["line"]["width"] = 3
                            fig.data[selected_ind]["textfont"]["color"] = "#371ea3"
                            fig.data[selected_ind]["textfont"]["size"] = 14
                            fig.data[k - selected_ind]["line"]["color"] = "#46bac2"
                            fig.data[k - selected_ind]["line"]["width"] = 3
                            fig.data[k - selected_ind]["textfont"]["color"] = "#371ea3"
                            fig.data[k - selected_ind]["textfont"]["size"] = 14
                            _update_childs(
                                fig.data[selected_ind]["x"][0],
                                fig.data[selected_ind]["y"][0],
                                fig,
                                k,
                                selected,
                                selected_y_cord,
                            )
                            _update_childs(
                                fig.data[selected_ind]["x"][-1],
                                fig.data[selected_ind]["y"][-1],
                                fig,
                                k,
                                selected,
                                selected_y_cord,
                            )
                            for i in range(1, k):
                                if i not in [selected_ind, k - selected_ind]:
                                    fig.data[i]["textfont"]["color"] = "#ceced9"
                                    fig.data[i]["textfont"]["size"] = 12
                                    if i not in selected:
                                        fig.data[i]["line"]["color"] = "#ceced9"
                                        fig.data[i]["line"]["width"] = 1

                            bars_colors_list = deepcopy(original_bar_colors)
                            text_colors_list = deepcopy(original_text_colors)
                            for i in range(m):
                                if i not in selected_y_cord:
                                    bars_colors_list[i] = "#ceced9"
                                    text_colors_list[i] = "#ceced9"
                            fig.data[0]["marker"]["color"] = bars_colors_list
                            fig.data[0]["textfont"]["color"] = text_colors_list

            for i in range(1, k):
                fig.data[i].on_click(_update_trace)
            return HBox([fig], layout=Layout(overflow='scroll', width=f'{fig.layout.width}px'))
        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, aspect, new_observation)

Calculate the result of explanation. The fit method makes calculations in place and changes the attributes.

Parameters

aspect : Aspect
Explainer wrapper created using the Aspect class.
new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
An observation for which a prediction needs to be explained.

Returns

None
 
Expand source code Browse git
def fit(self, aspect, new_observation):
    """Calculate the result of explanation
    Fit method makes calculations in place and changes the attributes.

    Parameters
    ----------
    aspect : Aspect object
        Explainer wrapper created using the Aspect class.
    new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
        An observation for which a prediction needs to be explained.
    
    Returns
    -------
    None
    """

    _new_observation = checks.check_new_observation(new_observation, aspect.explainer)
    checks.check_columns_in_new_observation(_new_observation, aspect.explainer)

    self._hierarchical_clustering_dendrogram = aspect._hierarchical_clustering_dendrogram

    self.prediction = aspect.explainer.predict(_new_observation)[0]
    self.intercept = aspect.explainer.y_hat.mean()

    ## middle plot data
    self.result = (
        utils.calculate_predict_hierarchical_importance(
            aspect,
            _new_observation,
            self.type,
            self.N,
            self.B,
            self.sample_method,
            self.f,
            self.processes,
            self.random_state,
        )
    )

    self.result.insert(3, "min_depend", None)
    self.result.insert(4, "vars_min_depend", None)
    for index, row in self.result.iterrows():
        _matching_row = aspect._dendrogram_aspects_ordered.loc[
            pd.Series(map(set, aspect._dendrogram_aspects_ordered.variable_names))
            == set(row.variable_names)
        ]
        min_dep = _matching_row.min_depend.values[0]
        vars_min_depend = _matching_row.vars_min_depend.values[0]
        self.result.at[index, "min_depend"] = min_dep
        self.result.at[
            index, "vars_min_depend"
        ] = vars_min_depend

    ## left plot data
    self.single_variable_importance = utils.calculate_single_variable_importance(
        aspect,
        _new_observation,
        self.type,
        self.N,
        self.B, 
        self.sample_method,
        self.f,
        self.processes,
        self.random_state
    )
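
For orientation, a minimal usage sketch, assuming a fitted scikit-learn-style model with data X and y, and assuming that the predict_triplot method of Aspect is the usual way to obtain a fitted PredictTriplot:

    import dalex as dx

    exp = dx.Explainer(model, X, y)      # model wrapper
    asp = dx.Aspect(exp)                 # variable dependencies and clustering
    pt = asp.predict_triplot(X.iloc[0])  # fitted PredictTriplot for one observation

    # fit() recomputes the result in place for another observation
    pt.fit(asp, X.iloc[1])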
def plot(self, absolute_value=False, digits=3, rounding_function=np.around, bar_width=25, width=1500, abbrev_labels=0, vcolors=None, title='Predict Triplot', widget=False, show=True)

Plot the Predict Triplot explanation (triplot visualization).

Parameters

absolute_value : bool, optional
If True, aspect importance values are drawn as absolute values (default is False).
digits : int, optional
Number of decimal places (np.around) used to round contributions; see the rounding_function parameter (default is 3).
rounding_function : function, optional
A function that will be used for rounding numbers (default is np.around).
bar_width : float, optional
Width of bars in px (default is 25).
width : float, optional
Width of the triplot in px (default is 1500).
abbrev_labels : int, optional
If greater than 0, Y-axis labels will be abbreviated according to this parameter (default is 0, which means no abbreviation).
vcolors : 2-tuple of str, optional
Colors of bars (default is ["#8bdcbe", "#f05a71"]).
title : str, optional
Title of the plot (default is "Predict Triplot").
widget : bool, optional
If True, an interactive widget version of the triplot is generated (default is False).
show : bool, optional
True shows the plot; False returns the plotly Figure object (default is True). NOTE: Ignored if widget is True.

Returns

None or plotly.graph_objects.Figure or ipywidgets.HBox with plotly.graph_objs._figurewidget.FigureWidget
Returns a figure that can be edited or saved; see the show parameter.
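
Because show=False returns a plotly Figure, the static triplot can be adjusted or exported before display. A brief sketch (pt as in the earlier sketch; the file name is illustrative, and static image export relies on the kaleido package):

    fig = pt.plot(show=False)                     # plotly Figure, not displayed
    fig.update_layout(title_text="Custom title")  # any plotly edit works here
    fig.write_image("triplot.png")                # static export via kaleido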
Expand source code Browse git
def plot(
    self,
    absolute_value=False,
    digits=3,
    rounding_function=np.around,
    bar_width=25,
    width=1500,
    abbrev_labels=0,
    vcolors=None,
    title="Predict Triplot",
    widget=False,
    show=True
):
    """Plot the Predict Triplot explanation (triplot visualization).

    Parameters
    ----------
    absolute_value : bool, optional
        If `True` aspect importance values are drawn as absolute values 
        (default is `False`).
    digits : int, optional
        Number of decimal places (`np.around`) to round contributions.
        See `rounding_function` parameter (default is `3`).
    rounding_function : function, optional
        A function that will be used for rounding numbers (default is `np.around`).
    bar_width : float, optional
        Width of bars in px (default is `25`).
    width : float, optional
        Width of triplot in px (default is `1500`).
    abbrev_labels : int, optional
        If greater than 0, labels for axis Y will be abbreviated according to this parameter
        (default is `0`, which means no abbreviation).
    vcolors : 2-tuple of str, optional
        Color of bars (default is `["#8bdcbe", "#f05a71"]`).
    title : str, optional
        Title of the plot (default is `"Predict Triplot"`).
    widget : bool, optional
        If `True` triplot interactive widget version is generated
        (default is `False`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object 
        (default is `True`).
        NOTE: Ignored if `widget` is `True`.

    Returns
    -------
    None or plotly.graph_objects.Figure or ipywidgets.HBox with plotly.graph_objs._figurewidget.FigureWidget
        Return figure that can be edited or saved. See `show` parameter.
    """    
    _global_checks.global_check_import('kaleido', 'Predict Triplot')
    ## right plot
    hierarchical_clustering_dendrogram_plot_without_annotations = go.Figure(
        self._hierarchical_clustering_dendrogram
    )
    variables_order = list(
        hierarchical_clustering_dendrogram_plot_without_annotations.layout.yaxis.ticktext
    )

    ## middle plot
    (
        hierarchical_importance_dendrogram_plot_without_annotations,
        updated_dendro_traces,
    ) = plot.plot_predict_hierarchical_importance(
        hierarchical_clustering_dendrogram_plot_without_annotations,
        self.result,
        rounding_function,
        digits,
        absolute_value,
        self.type,
    )

    hierarchical_clustering_dendrogram_plot = plot.add_text_to_dendrogram(
        hierarchical_clustering_dendrogram_plot_without_annotations,
        updated_dendro_traces,
        rounding_function,
        digits,
        type="clustering",
    )

    hierarchical_importance_dendrogram_plot = plot.add_text_to_dendrogram(
        hierarchical_importance_dendrogram_plot_without_annotations,
        updated_dendro_traces,
        rounding_function,
        digits,
        type="importance",
    )

    ## left plot
    fig = plot.plot_single_aspects_importance(
        self.single_variable_importance,
        variables_order,
        rounding_function,
        digits,
        vcolors,
    )
    fig.layout["xaxis"]["range"] = (
        fig.layout["xaxis"]["range"][0],
        fig.layout["xaxis"]["range"][1] * 1.05,
    )
    m = len(variables_order)
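    # bar centers at y = -5, -15, ... line up with the dendrogram leaves,
    # so a click can be mapped back to a bar row via (y + 5) // -10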
    y_vals = [-5 - i * 10 for i in range(m)]
    fig.data[0]["y"] = y_vals

    ## triplot

    fig.add_shape(
        type="line",
        x0=0,
        x1=0,
        y0=-0.01,
        y1=1.01,
        yref="paper",
        xref="x2",
        line={"color": "#371ea3", "width": 1.5, "dash": "dot"},
    )

    min_x_imp, max_x_imp = np.Inf, -np.Inf
    for data in hierarchical_importance_dendrogram_plot["data"][::-1]:
        data["xaxis"] = "x2"
        data["hoverinfo"] = "text"
        data["line"] = {"color": "#46bac2", "width": 2}
        fig.add_trace(data)
        min_x_imp = np.min([min_x_imp, np.min(data["x"])])
        max_x_imp = np.max([max_x_imp, np.max(data["x"])])
    min_max_margin_imp = (max_x_imp - min_x_imp) * 0.15

    min_x_clust, max_x_clust = np.Inf, -np.Inf
    for data in hierarchical_clustering_dendrogram_plot["data"]:
        data["xaxis"] = "x3"
        data["hoverinfo"] = "text"
        data["line"] = {"color": "#46bac2", "width": 2}
        fig.add_trace(data)
        min_x_clust = np.min([min_x_clust, np.min(data["x"])])
        max_x_clust = np.max([max_x_clust, np.max(data["x"])])
    min_max_margin_clust = (max_x_clust - min_x_clust) * 0.15

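    # total height = top margin (78) + bottom margin (71) + m bars
    # plus (m + 1) gaps of a quarter bar width each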
    plot_height = 78 + 71 + m * bar_width + (m + 1) * bar_width / 4
    ticktext = plot.get_ticktext_for_plot(
        self.single_variable_importance, variables_order, abbrev_labels
    )

    fig.update_layout(
        xaxis={
            "autorange": False,
            "domain": [0, 0.33],
            "mirror": False,
            "showgrid": False,
            "showline": False,
            "zeroline": False,
            "ticks": "",
            "title_text": "Local variable importance",
        },
        xaxis2={
            "domain": [0.33, 0.66],
            "mirror": False,
            "showgrid": False,
            "showline": False,
            "zeroline": False,
            "showticklabels": True,
            "tickvals": [0],
            "ticktext": [""],
            "ticks": "",
            "title_text": "Hierarchical aspect importance",
            "autorange": False,
            "fixedrange": True,
            "range": [
                min_x_imp - min_max_margin_imp,
                max_x_imp + min_max_margin_imp,
            ],
        },
        xaxis3={
            "domain": [0.66, 0.99],
            "mirror": False,
            "showgrid": False,
            "showline": False,
            "zeroline": False,
            "showticklabels": True,
            "tickvals": [0],
            "ticktext": [""],
            "ticks": "",
            "title_text": "Hierarchical clustering",
            "autorange": False,
            "fixedrange": True,
            "range": [
                min_x_clust - min_max_margin_clust,
                max_x_clust + min_max_margin_clust,
            ],
        },
        yaxis={
            "mirror": False,
            "ticks": "",
            "fixedrange": True,
            "gridwidth": 1,
            "type": "linear",
            "tickmode": "array",
            "tickvals": y_vals,
            "ticktext": ticktext,
        },
        title_text=title,
        title_x=0.5,
        font={"color": "#371ea3"},
        template="none",
        margin={"t": 78, "b": 71, "r": 30},
        width=width,
        height=plot_height,
        showlegend=False,
        hovermode="closest"
    )

    fig, middle_point = plot._add_points_on_dendrogram_traces(fig)

    ##################################################################

    if widget:
        _global_checks.global_check_import('ipywidgets', 'Predict Triplot')
        from ipywidgets import HBox, Layout
        fig = go.FigureWidget(fig, layout={"autosize": True, "hoverdistance": 100})
        original_bar_colors = deepcopy(list(fig.data[0]["marker"]["color"]))
        original_text_colors = deepcopy(list(fig.data[0]["textfont"]["color"]))
        k = len(fig.data)
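        # Full-figure indices (in both dendrogram panels) of the traces whose
        # positions were updated; a branch drawn as trace i in one panel has
        # its mirror at trace k - i in the other (trace 0 is the bar chart).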
        updated_dendro_traces_in_full_figure = list(
            np.array(updated_dendro_traces) + (k - 1) / 2 + 1
        ) + list((k - 1) / 2 - np.array(updated_dendro_traces))

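        # Recursively highlight the descendants of a branch: traces whose middle
        # point coincides with an endpoint of the parent are colored in both
        # panels, and the bar rows (leaves) they reach are recorded.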
        def _update_childs(x, y, fig, k, selected, selected_y_cord):
            for i in range(1, k):
                if middle_point[i] == (x, y):
                    fig.data[i]["line"]["color"] = "#46bac2"
                    fig.data[i]["line"]["width"] = 3
                    fig.data[k - i]["line"]["color"] = "#46bac2"
                    fig.data[k - i]["line"]["width"] = 3
                    selected.append(i)
                    selected.append(k - i)
                    if (fig.data[i]["y"][0] + 5) % 10 == 0:
                        selected_y_cord.append((fig.data[i]["y"][0] + 5) // -10)
                    if (fig.data[i]["y"][-1] - 5) % 10 == 0:
                        selected_y_cord.append((fig.data[i]["y"][-1] + 5) // -10)
                    _update_childs(
                        fig.data[i]["x"][0],
                        fig.data[i]["y"][0],
                        fig,
                        k,
                        selected,
                        selected_y_cord,
                    )
                    _update_childs(
                        fig.data[i]["x"][-1],
                        fig.data[i]["y"][-1],
                        fig,
                        k,
                        selected,
                        selected_y_cord,
                    )

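        # Click handler: clicking outside the updated traces resets the styling;
        # clicking a branch highlights it (in both panels) together with its
        # descendants and the bars of the variables it covers, graying out the rest.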
        def _update_trace(trace, points, selector):
            if len(points.point_inds) == 1:
                selected_ind = points.trace_index
                with fig.batch_update():
                    if selected_ind not in updated_dendro_traces_in_full_figure:
                        for i in range(1, k):
                            fig.data[i]["line"]["color"] = "#46bac2"
                            fig.data[i]["line"]["width"] = 2
                            fig.data[i]["textfont"]["color"] = "#371ea3"
                            fig.data[i]["textfont"]["size"] = 12
                        fig.data[0]["marker"]["color"] = original_bar_colors
                        fig.data[0]["textfont"]["color"] = original_text_colors
                    else:
                        selected = [selected_ind, k - selected_ind]
                        selected_y_cord = []
                        if (fig.data[selected_ind]["y"][0] - 5) % 10 == 0:
                            selected_y_cord.append(
                                (fig.data[selected_ind]["y"][0] + 5) // -10
                            )
                        if (fig.data[selected_ind]["y"][-1] - 5) % 10 == 0:
                            selected_y_cord.append(
                                (fig.data[selected_ind]["y"][-1] + 5) // -10
                            )
                        fig.data[selected_ind]["line"]["color"] = "#46bac2"
                        fig.data[selected_ind]["line"]["width"] = 3
                        fig.data[selected_ind]["textfont"]["color"] = "#371ea3"
                        fig.data[selected_ind]["textfont"]["size"] = 14
                        fig.data[k - selected_ind]["line"]["color"] = "#46bac2"
                        fig.data[k - selected_ind]["line"]["width"] = 3
                        fig.data[k - selected_ind]["textfont"]["color"] = "#371ea3"
                        fig.data[k - selected_ind]["textfont"]["size"] = 14
                        _update_childs(
                            fig.data[selected_ind]["x"][0],
                            fig.data[selected_ind]["y"][0],
                            fig,
                            k,
                            selected,
                            selected_y_cord,
                        )
                        _update_childs(
                            fig.data[selected_ind]["x"][-1],
                            fig.data[selected_ind]["y"][-1],
                            fig,
                            k,
                            selected,
                            selected_y_cord,
                        )
                        for i in range(1, k):
                            if i not in [selected_ind, k - selected_ind]:
                                fig.data[i]["textfont"]["color"] = "#ceced9"
                                fig.data[i]["textfont"]["size"] = 12
                                if i not in selected:
                                    fig.data[i]["line"]["color"] = "#ceced9"
                                    fig.data[i]["line"]["width"] = 1

                        bars_colors_list = deepcopy(original_bar_colors)
                        text_colors_list = deepcopy(original_text_colors)
                        for i in range(m):
                            if i not in selected_y_cord:
                                bars_colors_list[i] = "#ceced9"
                                text_colors_list[i] = "#ceced9"
                        fig.data[0]["marker"]["color"] = bars_colors_list
                        fig.data[0]["textfont"]["color"] = text_colors_list

        for i in range(1, k):
            fig.data[i].on_click(_update_trace)
        return HBox([fig], layout=Layout(overflow='scroll', width=f'{fig.layout.width}px'))
    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
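
A short sketch of the interactive path, assuming a Jupyter environment with ipywidgets installed and asp, X as in the earlier sketch. Clicking a branch in either dendrogram highlights the corresponding aspect across all three panels:

    pt = asp.predict_triplot(X.iloc[0])
    pt.plot(widget=True)  # ipywidgets.HBox wrapping a clickable FigureWidget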