Module dalex.model_explanations

Expand source code Browse git
from ._aggregated_profiles.object import AggregatedProfiles
from ._model_performance.object import ModelPerformance
from ._variable_importance.object import VariableImportance
from ._residual_diagnostics import ResidualDiagnostics

# Public explanation classes re-exported by this package (see the imports above).
__all__ = ["ModelPerformance", "VariableImportance",
           "AggregatedProfiles", "ResidualDiagnostics"]

Classes

class AggregatedProfiles (type='partial', variables=None, variable_type='numerical', groups=None, span=0.25, center=True, random_state=None)

Calculate model-level variable profiles as Partial, Conditional or Accumulated Dependence

  • Partial Dependence Profile (average across Ceteris Paribus Profiles),
  • Conditional (Local) Dependence Profile (local weighted average across CP Profiles),
  • Accumulated Local Effects (cumulated average local changes in CP Profiles).

Parameters

type : {'partial', 'accumulated', 'conditional'}
Type of model profiles (default is 'partial' for Partial Dependence Profiles).
variables : str or array_like of str, optional
Variables for which the profiles will be calculated (default is None, which means all of the variables).
variable_type : {'numerical', 'categorical'}
Calculate the profiles for numerical or categorical variables (default is 'numerical').
groups : str or array_like of str, optional
Names of categorical variables that will be used for profile grouping (default is None, which means no grouping).
span : float, optional
Smoothing coefficient used as the standard deviation of the Gaussian kernel (default is 0.25).
center : bool, optional
Theoretically, Accumulated Profiles start at 0, but they are centered so that they can be compared with Partial Dependence Profiles (default is True, which means centering around the average y_hat calculated on the data sample).
random_state : int, optional
Set seed for random number generator (default is random seed).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
mean_prediction : float
Average prediction for sampled data (using N).
raw_profiles : CeterisParibus or None
Copy of the CeterisParibus object passed to the fit method. NOTE: None if a list or tuple of objects was passed to the fit method.
type : {'partial', 'accumulated', 'conditional'}
Type of model profiles.
variables : array_like of str or None
Variables for which the profiles will be calculated.
variable_type : {'numerical', 'categorical'}
Calculate the profiles for numerical or categorical variables.
groups : str or array_like of str or None
Names of categorical variables that will be used for profile grouping.
span : float
Smoothing coefficient used as the standard deviation of the Gaussian kernel.
center : bool
Theoretically, Accumulated Profiles start at 0, but they are centered so that they can be compared with Partial Dependence Profiles (default is True, which means centering around the average y_hat calculated on the data sample).
random_state : int or None
Set seed for random number generator.

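Notes

  • https://pbiecek.github.io/ema/partialDependenceProfiles.html
  • https://pbiecek.github.io/ema/accumulatedLocalProfiles.html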

Expand source code Browse git
class AggregatedProfiles(Explanation):
    """Calculate model-level variable profiles as Partial, Conditional or Accumulated Dependence

    - Partial Dependence Profile (average across Ceteris Paribus Profiles),
    - Conditional (Local) Dependence Profile (local weighted average across CP Profiles),
    - Accumulated Local Effects (cumulated average local changes in CP Profiles).

    Parameters
    -----------
    type : {'partial', 'accumulated', 'conditional'}
        Type of model profiles
        (default is `'partial'` for Partial Dependence Profiles).
    variables : str or array_like of str, optional
        Variables for which the profiles will be calculated
        (default is `None`, which means all of the variables).
    variable_type : {'numerical', 'categorical'}
        Calculate the profiles for numerical or categorical variables
        (default is `'numerical'`).
    groups : str or array_like of str, optional
        Names of categorical variables that will be used for profile grouping
        (default is `None`, which means no grouping).
    span : float, optional
        Smoothing coefficient used as the standard deviation of the Gaussian kernel
        (default is `0.25`).
    center : bool, optional
        Theoretically, Accumulated Profiles start at `0`, but they are centered so that
        they can be compared with Partial Dependence Profiles (default is `True`, which
        means centering around the average y_hat calculated on the data sample).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    mean_prediction : float
        Average prediction for sampled `data` (using `N`).
    raw_profiles : CeterisParibus or None
        Copy of the CeterisParibus object passed to the `fit` method.
        NOTE: `None` if a list or tuple of objects was passed to the `fit` method.
    type : {'partial', 'accumulated', 'conditional'}
        Type of model profiles.
    variables : array_like of str or None
        Variables for which the profiles will be calculated.
    variable_type : {'numerical', 'categorical'}
        Calculate the profiles for numerical or categorical variables.
    groups : str or array_like of str or None
        Names of categorical variables that will be used for profile grouping.
    span : float
        Smoothing coefficient used as the standard deviation of the Gaussian kernel.
    center : bool
        Theoretically, Accumulated Profiles start at `0`, but they are centered so that
        they can be compared with Partial Dependence Profiles (default is `True`, which
        means centering around the average y_hat calculated on the data sample).
    random_state : int or None
        Set seed for random number generator.

    Notes
    --------
    - https://pbiecek.github.io/ema/partialDependenceProfiles.html
    - https://pbiecek.github.io/ema/accumulatedLocalProfiles.html
    """

    def __init__(self,
                 type='partial',
                 variables=None,
                 variable_type='numerical',
                 groups=None,
                 span=0.25,
                 center=True,
                 random_state=None):

        checks.check_variable_type(variable_type)
        _variables = checks.check_variables(variables)
        _groups = checks.check_groups(groups)

        self.variable_type = variable_type
        self.groups = _groups
        self.type = type
        self.variables = _variables
        self.span = span
        self.center = center
        self.result = pd.DataFrame()
        self.mean_prediction = None
        self.raw_profiles = None
        # NOTE(review): random_state is stored but not read by fit/plot in this class;
        # confirm whether it is consumed elsewhere (e.g. by the caller that samples data).
        self.random_state = random_state

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self,
            ceteris_paribus,
            verbose=True):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        ceteris_paribus : CeterisParibus object or array_like of CeterisParibus objects
            Profile objects to aggregate.
        verbose : bool, optional
            Print tqdm progress bar (default is `True`).

        Returns
        -----------
        None

        Raises
        -----------
        ValueError
            If an empty list or tuple of profile objects is passed.
        """
        # local import, presumably to avoid a circular import between the subpackages
        from dalex.predict_explanations import CeterisParibus

        if isinstance(ceteris_paribus, CeterisParibus):  # allow for ceteris_paribus to be a single element
            all_profiles = ceteris_paribus.result.copy()
            all_observations = ceteris_paribus.new_observation.copy()
            # kept so that plot(geom='profiles') can draw the raw profiles in the background
            self.raw_profiles = deepcopy(ceteris_paribus)
        elif isinstance(ceteris_paribus, (list, tuple)):  # ceteris_paribus as tuple or array
            if len(ceteris_paribus) == 0:
                raise ValueError("ceteris_paribus should contain at least one CeterisParibus object")
            for cp in ceteris_paribus:
                _global_checks.global_check_object_class(cp, CeterisParibus)
            # one concat instead of re-concatenating the growing frame on every iteration
            all_profiles = pd.concat([cp.result.copy() for cp in ceteris_paribus])
            all_observations = pd.concat([cp.new_observation.copy() for cp in ceteris_paribus])
            # raw profiles are not kept for several objects; drop any left over from a previous fit
            self.raw_profiles = None
        else:
            _global_checks.global_raise_objects_class(ceteris_paribus, CeterisParibus)

        all_profiles, vnames = utils.prepare_numerical_categorical(all_profiles, self.variables, self.variable_type)

        # select only suitable variables
        all_profiles = all_profiles.loc[all_profiles['_vname_'].isin(vnames), :]

        all_profiles = utils.prepare_x(all_profiles, self.variable_type)

        self.mean_prediction = all_observations['_yhat_'].mean()

        self.result = utils.aggregate_profiles(all_profiles,
                                               self.mean_prediction,
                                               self.type,
                                               self.groups,
                                               self.center,
                                               self.span,
                                               verbose)

    def plot(self,
             objects=None,
             geom='aggregates',
             variables=None,
             center=True,
             size=2,
             alpha=1,
             color='_label_',
             facet_ncol=2,
             title="Aggregated Profiles",
             y_title='prediction',
             horizontal_spacing=0.05,
             vertical_spacing=None,
             show=True):
        """Plot the Aggregated Profiles explanation

        Parameters
        -----------
        objects : AggregatedProfiles object or array_like of AggregatedProfiles objects
            Additional objects to plot in subplots (default is `None`).
        geom : {'aggregates', 'profiles', 'bars'}
            If `'profiles'`, the raw Ceteris Paribus profiles are plotted in the background
            (only available when a single CeterisParibus object was passed to `fit`);
            `'bars'` overrides the `_x_` column type and uses bar plots for categorical data
            (default is `'aggregates'`, which means plot only aggregated profiles).
            NOTE: It is useful to use small values of the `N` parameter in object creation
            before using `'profiles'`, because of plot performance and clarity (e.g. `100`).
        variables : str or array_like of str, optional
            Variables to plot (default is `None`, which means all of the variables).
        center : bool, optional
            Theoretically, Accumulated Profiles start at `0`, but they are centered so that
            they can be compared with Partial Dependence Profiles (default is `True`, which
            means centering around the average y_hat calculated on the data sample).
        size : float, optional
            Width of lines in px (default is `2`).
        alpha : float <0, 1>, optional
            Opacity of lines (default is `1`).
        color : str, optional
            Variable name used for grouping
            (default is `'_label_'`, which groups by models).
        facet_ncol : int, optional
            Number of columns on the plot grid (default is `2`).
        title : str, optional
            Title of the plot (default is `"Aggregated Profiles"`).
        y_title : str, optional
            Title of the y axis (default is `"prediction"`).
        horizontal_spacing : float <0, 1>, optional
            Ratio of horizontal space between the plots (default is `0.05`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots (default is `0.3/number of rows`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.

        Raises
        -----------
        TypeError
            If `geom` is not one of the supported values, or if none of `variables`
            are present in the result.
        """

        if geom not in ("aggregates", "profiles", "bars"):
            raise TypeError("geom should be one of {'aggregates', 'profiles', 'bars'}")
        if isinstance(variables, str):
            variables = (variables,)

        def _with_mean_prediction(obj):
            # `_mp_` is the reference level of the plot: the mean prediction, or 0 when not centering
            return obj.result.assign(_mp_=obj.mean_prediction if center else 0)

        # are there any other objects to plot?
        if objects is None:
            _result_df = _with_mean_prediction(self)
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _result_df = pd.concat([_with_mean_prediction(self), _with_mean_prediction(objects)])
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
            _result_df = pd.concat([_with_mean_prediction(self)] +
                                   [_with_mean_prediction(ob) for ob in objects])
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        # variables to use
        all_variables = _result_df['_vname_'].dropna().unique().tolist()

        if variables is not None:
            all_variables = _global_utils.intersect_unsorted(variables, all_variables)
            if len(all_variables) == 0:
                raise TypeError("variables do not overlap with " + ', '.join(map(str, variables)))

            _result_df = _result_df.loc[_result_df['_vname_'].isin(all_variables), :]

        #  calculate y axis range to allow for fixedrange True
        dl = _result_df['_yhat_'].to_numpy()
        # np.ptp rather than the ndarray.ptp method, which was removed in NumPy 2.0
        min_max_margin = np.ptp(dl) * 0.10
        min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

        is_x_numeric = False if geom == 'bars' else pd.api.types.is_numeric_dtype(_result_df['_x_'])
        n = len(all_variables)

        facet_nrow = int(np.ceil(n / facet_ncol))
        if vertical_spacing is None:
            vertical_spacing = 0.3 / facet_nrow
        plot_height = 78 + 71 + facet_nrow * (280 + 60)
        hovermode, render_mode = 'x unified', 'svg'

        # color = '_groups_' doesnt make much sense for multiple AP objects
        m = len(_result_df[color].dropna().unique())

        if is_x_numeric:
            if geom == 'profiles' and self.raw_profiles is not None:
                render_mode = 'webgl'

            fig = px.line(_result_df,
                          x="_x_", y="_yhat_", color=color, facet_col="_vname_",
                          category_orders={"_vname_": list(all_variables)},
                          labels={'_yhat_': 'prediction', '_label_': 'label', '_mp_': 'mean_prediction'},
                          hover_name=color,
                          hover_data={'_yhat_': ':.3f', '_mp_': ':.3f',
                                      color: False, '_vname_': False, '_x_': False},
                          facet_col_wrap=facet_ncol,
                          facet_row_spacing=vertical_spacing,
                          facet_col_spacing=horizontal_spacing,
                          template="none",
                          render_mode=render_mode,
                          color_discrete_sequence=_theme.get_default_colors(m, 'line')) \
                    .update_traces(dict(line_width=size, opacity=alpha)) \
                    .update_xaxes({'matches': None, 'showticklabels': True,
                                   'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                                   'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}) \
                    .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                                   'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                                   'range': min_max})

            if geom == 'profiles' and self.raw_profiles is not None:
                # emphasise the aggregates, then draw them on top of greyed-out raw profiles
                fig.update_traces(dict(line_width=2*size, opacity=1))
                fig_cp = self.raw_profiles.plot(variables=list(all_variables),
                                                facet_ncol=facet_ncol,
                                                show_observations=False, show=False) \
                    .update_traces(dict(line_width=1, opacity=0.5, line_color='#ceced9'))

                for value in fig.data:
                    fig_cp.add_trace(value)
                hovermode = False
                fig = fig_cp
        else:
            # bars are drawn as differences from the `_mp_` base so they grow from the reference level
            _result_df = _result_df.assign(_diff_=lambda x: x['_yhat_'] - x['_mp_'])
            mp_format = ':.3f'
            if not center:
                # bars start at 0, so make sure 0 is inside the fixed y range
                min_max = [np.min([min_max[0], 0]), np.max([min_max[1], 0])]
                mp_format = False

            fig = px.bar(_result_df,
                         x="_x_", y="_diff_", color="_label_", facet_col="_vname_",
                         category_orders={"_vname_": list(all_variables)},
                         labels={'_yhat_': 'prediction', '_label_': 'label', '_mp_': 'mean_prediction'},
                         hover_name=color,
                         base="_mp_",
                         hover_data={'_yhat_': ':.3f', '_mp_': mp_format, '_diff_': False,
                                     color: False, '_vname_': False, '_x_': False},
                         facet_col_wrap=facet_ncol,
                         facet_row_spacing=vertical_spacing,
                         facet_col_spacing=horizontal_spacing,
                         template="none",
                         color_discrete_sequence=_theme.get_default_colors(m, 'line'),  # bar was forgotten
                         barmode='group')  \
                    .update_xaxes({'matches': None, 'showticklabels': True,
                                   'type': 'category', 'gridwidth': 2, 'automargin': True,
                                   'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}) \
                    .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                                   'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                                   'range': min_max})

            # dotted reference line at each trace's base (mean prediction or 0)
            for bar in fig.data:
                fig.add_hline(y=bar.base[0], layer='below',
                              line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'})

        fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, hovermode)

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, ceteris_paribus, verbose=True)

Calculate the result of explanation

Fit method makes calculations in place and changes the attributes.

Parameters

ceteris_paribus : CeterisParibus object or array_like of CeterisParibus objects
Profile objects to aggregate.
verbose : bool, optional
Print tqdm progress bar (default is True).

Returns

None
 
Expand source code Browse git
def fit(self,
        ceteris_paribus,
        verbose=True):
    """Calculate the result of explanation

    Fit method makes calculations in place and changes the attributes.

    Parameters
    -----------
    ceteris_paribus : CeterisParibus object or array_like of CeterisParibus objects
        Profile objects to aggregate.
    verbose : bool, optional
        Print tqdm progress bar (default is `True`).

    Returns
    -----------
    None

    Raises
    -----------
    ValueError
        If an empty list or tuple of profile objects is passed.
    """
    # local import, presumably to avoid a circular import between the subpackages
    from dalex.predict_explanations import CeterisParibus

    if isinstance(ceteris_paribus, CeterisParibus):  # allow for ceteris_paribus to be a single element
        all_profiles = ceteris_paribus.result.copy()
        all_observations = ceteris_paribus.new_observation.copy()
        # kept so that plot(geom='profiles') can draw the raw profiles in the background
        self.raw_profiles = deepcopy(ceteris_paribus)
    elif isinstance(ceteris_paribus, (list, tuple)):  # ceteris_paribus as tuple or array
        if len(ceteris_paribus) == 0:
            raise ValueError("ceteris_paribus should contain at least one CeterisParibus object")
        for cp in ceteris_paribus:
            _global_checks.global_check_object_class(cp, CeterisParibus)
        # one concat instead of re-concatenating the growing frame on every iteration
        all_profiles = pd.concat([cp.result.copy() for cp in ceteris_paribus])
        all_observations = pd.concat([cp.new_observation.copy() for cp in ceteris_paribus])
        # raw profiles are not kept for several objects; drop any left over from a previous fit
        self.raw_profiles = None
    else:
        _global_checks.global_raise_objects_class(ceteris_paribus, CeterisParibus)

    all_profiles, vnames = utils.prepare_numerical_categorical(all_profiles, self.variables, self.variable_type)

    # select only suitable variables
    all_profiles = all_profiles.loc[all_profiles['_vname_'].isin(vnames), :]

    all_profiles = utils.prepare_x(all_profiles, self.variable_type)

    self.mean_prediction = all_observations['_yhat_'].mean()

    self.result = utils.aggregate_profiles(all_profiles,
                                           self.mean_prediction,
                                           self.type,
                                           self.groups,
                                           self.center,
                                           self.span,
                                           verbose)
def plot(self, objects=None, geom='aggregates', variables=None, center=True, size=2, alpha=1, color='_label_', facet_ncol=2, title='Aggregated Profiles', y_title='prediction', horizontal_spacing=0.05, vertical_spacing=None, show=True)

Plot the Aggregated Profiles explanation

Parameters

objects : AggregatedProfiles object or array_like of AggregatedProfiles objects
Additional objects to plot in subplots (default is None).
geom : {'aggregates', 'profiles', 'bars'}
If 'profiles', the raw Ceteris Paribus profiles are plotted in the background (only available when a single CeterisParibus object was passed to fit); 'bars' overrides the _x_ column type and uses bar plots for categorical data (default is 'aggregates', which means plot only aggregated profiles). NOTE: It is useful to use small values of the N parameter in object creation before using 'profiles', because of plot performance and clarity (e.g. 100).
variables : str or array_like of str, optional
Variables to plot (default is None, which means all of the variables).
center : bool, optional
Theoretically, Accumulated Profiles start at 0, but they are centered so that they can be compared with Partial Dependence Profiles (default is True, which means centering around the average y_hat calculated on the data sample).
size : float, optional
Width of lines in px (default is 2).
alpha : float <0, 1>, optional
Opacity of lines (default is 1).
color : str, optional
Variable name used for grouping (default is '_label_', which groups by models).
facet_ncol : int, optional
Number of columns on the plot grid (default is 2).
title : str, optional
Title of the plot (default is "Aggregated Profiles").
y_title : str, optional
Title of the y axis (default is "prediction").
horizontal_spacing : float <0, 1>, optional
Ratio of horizontal space between the plots (default is 0.05).
vertical_spacing : float <0, 1>, optional
Ratio of vertical space between the plots (default is 0.3/number of rows).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
Expand source code Browse git
def plot(self,
         objects=None,
         geom='aggregates',
         variables=None,
         center=True,
         size=2,
         alpha=1,
         color='_label_',
         facet_ncol=2,
         title="Aggregated Profiles",
         y_title='prediction',
         horizontal_spacing=0.05,
         vertical_spacing=None,
         show=True):
    """Plot the Aggregated Profiles explanation

    Parameters
    -----------
    objects : AggregatedProfiles object or array_like of AggregatedProfiles objects
        Additional objects to plot in subplots (default is `None`).
    geom : {'aggregates', 'profiles', 'bars'}
        If `'profiles'`, the raw Ceteris Paribus profiles are plotted in the background
        (only available when a single CeterisParibus object was passed to `fit`);
        `'bars'` overrides the `_x_` column type and uses bar plots for categorical data
        (default is `'aggregates'`, which means plot only aggregated profiles).
        NOTE: It is useful to use small values of the `N` parameter in object creation
        before using `'profiles'`, because of plot performance and clarity (e.g. `100`).
    variables : str or array_like of str, optional
        Variables to plot (default is `None`, which means all of the variables).
    center : bool, optional
        Theoretically, Accumulated Profiles start at `0`, but they are centered so that
        they can be compared with Partial Dependence Profiles (default is `True`, which
        means centering around the average y_hat calculated on the data sample).
    size : float, optional
        Width of lines in px (default is `2`).
    alpha : float <0, 1>, optional
        Opacity of lines (default is `1`).
    color : str, optional
        Variable name used for grouping
        (default is `'_label_'`, which groups by models).
    facet_ncol : int, optional
        Number of columns on the plot grid (default is `2`).
    title : str, optional
        Title of the plot (default is `"Aggregated Profiles"`).
    y_title : str, optional
        Title of the y axis (default is `"prediction"`).
    horizontal_spacing : float <0, 1>, optional
        Ratio of horizontal space between the plots (default is `0.05`).
    vertical_spacing : float <0, 1>, optional
        Ratio of vertical space between the plots (default is `0.3/number of rows`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.

    Raises
    -----------
    TypeError
        If `geom` is not one of the supported values, or if none of `variables`
        are present in the result.
    """

    if geom not in ("aggregates", "profiles", "bars"):
        raise TypeError("geom should be one of {'aggregates', 'profiles', 'bars'}")
    if isinstance(variables, str):
        variables = (variables,)

    def _with_mean_prediction(obj):
        # `_mp_` is the reference level of the plot: the mean prediction, or 0 when not centering
        return obj.result.assign(_mp_=obj.mean_prediction if center else 0)

    # are there any other objects to plot?
    if objects is None:
        _result_df = _with_mean_prediction(self)
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _result_df = pd.concat([_with_mean_prediction(self), _with_mean_prediction(objects)])
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
        _result_df = pd.concat([_with_mean_prediction(self)] +
                               [_with_mean_prediction(ob) for ob in objects])
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    # variables to use
    all_variables = _result_df['_vname_'].dropna().unique().tolist()

    if variables is not None:
        all_variables = _global_utils.intersect_unsorted(variables, all_variables)
        if len(all_variables) == 0:
            raise TypeError("variables do not overlap with " + ', '.join(map(str, variables)))

        _result_df = _result_df.loc[_result_df['_vname_'].isin(all_variables), :]

    #  calculate y axis range to allow for fixedrange True
    dl = _result_df['_yhat_'].to_numpy()
    # np.ptp rather than the ndarray.ptp method, which was removed in NumPy 2.0
    min_max_margin = np.ptp(dl) * 0.10
    min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

    is_x_numeric = False if geom == 'bars' else pd.api.types.is_numeric_dtype(_result_df['_x_'])
    n = len(all_variables)

    facet_nrow = int(np.ceil(n / facet_ncol))
    if vertical_spacing is None:
        vertical_spacing = 0.3 / facet_nrow
    plot_height = 78 + 71 + facet_nrow * (280 + 60)
    hovermode, render_mode = 'x unified', 'svg'

    # color = '_groups_' doesnt make much sense for multiple AP objects
    m = len(_result_df[color].dropna().unique())

    if is_x_numeric:
        if geom == 'profiles' and self.raw_profiles is not None:
            render_mode = 'webgl'

        fig = px.line(_result_df,
                      x="_x_", y="_yhat_", color=color, facet_col="_vname_",
                      category_orders={"_vname_": list(all_variables)},
                      labels={'_yhat_': 'prediction', '_label_': 'label', '_mp_': 'mean_prediction'},
                      hover_name=color,
                      hover_data={'_yhat_': ':.3f', '_mp_': ':.3f',
                                  color: False, '_vname_': False, '_x_': False},
                      facet_col_wrap=facet_ncol,
                      facet_row_spacing=vertical_spacing,
                      facet_col_spacing=horizontal_spacing,
                      template="none",
                      render_mode=render_mode,
                      color_discrete_sequence=_theme.get_default_colors(m, 'line')) \
                .update_traces(dict(line_width=size, opacity=alpha)) \
                .update_xaxes({'matches': None, 'showticklabels': True,
                               'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                               'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}) \
                .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                               'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                               'range': min_max})

        if geom == 'profiles' and self.raw_profiles is not None:
            # emphasise the aggregates, then draw them on top of greyed-out raw profiles
            fig.update_traces(dict(line_width=2*size, opacity=1))
            fig_cp = self.raw_profiles.plot(variables=list(all_variables),
                                            facet_ncol=facet_ncol,
                                            show_observations=False, show=False) \
                .update_traces(dict(line_width=1, opacity=0.5, line_color='#ceced9'))

            for value in fig.data:
                fig_cp.add_trace(value)
            hovermode = False
            fig = fig_cp
    else:
        # bars are drawn as differences from the `_mp_` base so they grow from the reference level
        _result_df = _result_df.assign(_diff_=lambda x: x['_yhat_'] - x['_mp_'])
        mp_format = ':.3f'
        if not center:
            # bars start at 0, so make sure 0 is inside the fixed y range
            min_max = [np.min([min_max[0], 0]), np.max([min_max[1], 0])]
            mp_format = False

        fig = px.bar(_result_df,
                     x="_x_", y="_diff_", color="_label_", facet_col="_vname_",
                     category_orders={"_vname_": list(all_variables)},
                     labels={'_yhat_': 'prediction', '_label_': 'label', '_mp_': 'mean_prediction'},
                     hover_name=color,
                     base="_mp_",
                     hover_data={'_yhat_': ':.3f', '_mp_': mp_format, '_diff_': False,
                                 color: False, '_vname_': False, '_x_': False},
                     facet_col_wrap=facet_ncol,
                     facet_row_spacing=vertical_spacing,
                     facet_col_spacing=horizontal_spacing,
                     template="none",
                     color_discrete_sequence=_theme.get_default_colors(m, 'line'),  # bar was forgotten
                     barmode='group')  \
                .update_xaxes({'matches': None, 'showticklabels': True,
                               'type': 'category', 'gridwidth': 2, 'automargin': True,
                               'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}) \
                .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                               'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                               'range': min_max})

        # dotted reference line at each trace's base (mean prediction or 0)
        for bar in fig.data:
            fig.add_hline(y=bar.base[0], layer='below',
                          line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'})

    fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, hovermode)

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class ModelPerformance (model_type, cutoff=0.5)

Calculate model-level model performance measures

Parameters

model_type : {'regression', 'classification'}
Model task type that is used to choose the proper performance measures.
cutoff : float, optional
Cutoff for predictions in classification models. Needed for measures like recall, precision, acc, f1 (default is 0.5).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
residuals : pd.DataFrame
Residuals for data.
model_type : {'regression', 'classification'}
Model task type that is used to choose the proper performance measures.
cutoff : float
Cutoff for predictions in classification models.

Notes

Expand source code Browse git
class ModelPerformance(Explanation):
    """Calculate model-level model performance measures

    Parameters
    -----------
    model_type : {'regression', 'classification'}
        Model task type that is used to choose the proper performance measures.
    cutoff : float, optional
        Cutoff for predictions in classification models. Needed for measures like
        recall, precision, acc, f1 (default is `0.5`).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    residuals : pd.DataFrame
        Residuals for `data`, with columns `y_hat`, `y`, `residuals` and `label`
        (`None` until `fit()` is called).
    model_type : {'regression', 'classification'}
        Model task type that is used to choose the proper performance measures.
    cutoff : float
        Cutoff for predictions in classification models.

    Notes
    --------
    - https://pbiecek.github.io/ema/modelPerformance.html
    """
    def __init__(self,
                 model_type,
                 cutoff=0.5):

        self.cutoff = cutoff
        self.model_type = model_type
        self.result = pd.DataFrame()
        # populated by `fit()`; `plot()` reads it
        self.residuals = None

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self, explainer):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.

        Returns
        -----------
        None

        Raises
        -----------
        ValueError
            If `model_type` is neither 'regression' nor 'classification'.
        """

        # reuse predictions/residuals already stored on the explainer if present
        if explainer.y_hat is not None:
            y_pred = explainer.y_hat
        else:
            y_pred = explainer.predict(explainer.data)

        if explainer.residuals is not None:
            _residuals = explainer.residuals
        else:
            _residuals = explainer.residual(explainer.data, explainer.y)

        y_true = explainer.y

        if self.model_type == 'regression':
            _mse = utils.mse(y_pred, y_true)
            _rmse = utils.rmse(y_pred, y_true)
            _r2 = utils.r2(y_pred, y_true)
            _mae = utils.mae(y_pred, y_true)
            _mad = utils.mad(y_pred, y_true)

            self.result = pd.DataFrame(
                {
                    'mse': [_mse],
                    'rmse': [_rmse],
                    'r2': [_r2],
                    'mae': [_mae],
                    'mad': [_mad]
                }, index=[explainer.label])
        elif self.model_type == 'classification':
            # confusion-matrix counts at the configured cutoff
            tp = ((y_true == 1) * (y_pred >= self.cutoff)).sum()
            fp = ((y_true == 0) * (y_pred >= self.cutoff)).sum()
            tn = ((y_true == 0) * (y_pred < self.cutoff)).sum()
            fn = ((y_true == 1) * (y_pred < self.cutoff)).sum()

            _recall = utils.recall(tp, fp, tn, fn)
            _precision = utils.precision(tp, fp, tn, fn)
            _f1 = utils.f1(tp, fp, tn, fn)
            _accuracy = utils.accuracy(tp, fp, tn, fn)
            _auc = utils.auc(y_pred, y_true)

            self.result = pd.DataFrame({
                'recall': [_recall],
                'precision': [_precision],
                'f1': [_f1],
                'accuracy': [_accuracy],
                'auc': [_auc]
            }, index=[explainer.label])
        else:
            raise ValueError("'model_type' must be 'regression' or 'classification'")

        self.residuals = pd.DataFrame({
            'y_hat': y_pred,
            'y': y_true,
            'residuals': _residuals,
            'label': explainer.label
        })

    def plot(self,
             objects=None,
             geom="ecdf",
             title=None,
             show=False):
        """Plot the Model Performance explanation

        Parameters
        -----------
        objects : ModelPerformance object or array_like of ModelPerformance objects
            Additional objects to plot (default is `None`).
        geom : {'ecdf', 'roc', 'lift'}
            Type of plot determines how residuals shall be summarized
            (default is `'ecdf'`).
        title : str, optional
            Title of the plot (default is `None`; the value is passed unchanged
            to the plotting helper selected by `geom`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can
            be edited or saved using the `write_image()` method (default is `False`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.

        Raises
        -----------
        TypeError
            If `geom` is not one of 'ecdf', 'roc', 'lift'.
        """

        # validate once, before any data is copied
        if geom not in ("ecdf", "roc", "lift"):
            raise TypeError("geom should be one of {'ecdf', 'roc', 'lift'}")

        # are there any other objects to plot?
        if objects is None:
            _df_list = [self.residuals.copy()]
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _df_list = [self.residuals.copy(), objects.residuals.copy()]
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            _df_list = [self.residuals.copy()]
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _df_list += [ob.residuals.copy()]
        else:
            # NOTE(review): assumed to raise; otherwise `_df_list` would be unbound below
            _global_checks.global_raise_objects_class(objects, self.__class__)

        colors = _theme.get_default_colors(len(_df_list), 'line')

        # geom was validated above, so the last branch can only be 'lift'
        if geom == 'ecdf':
            fig = plot.plot_ecdf(_df_list, colors, title)
        elif geom == 'roc':
            fig = plot.plot_roc(_df_list, colors, title)
        else:
            fig = plot.plot_lift(_df_list, colors, title)

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, explainer)

Calculate the result of explanation

Fit method makes calculations in place and changes the attributes.

Parameters

explainer : Explainer object
Model wrapper created using the Explainer class.

Returns

None
 
Expand source code Browse git
def fit(self, explainer):
    """Compute performance measures and residuals for an explainer.

    Works in place. On success `self.result` holds a one-row DataFrame of
    measures indexed by the explainer label, and `self.residuals` holds a
    per-observation DataFrame with columns `y_hat`, `y`, `residuals`, `label`.

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.

    Returns
    -----------
    None

    Raises
    -----------
    ValueError
        If `self.model_type` is neither 'regression' nor 'classification'.
    """

    # prefer values cached on the explainer; compute them only when missing
    y_pred = explainer.predict(explainer.data) if explainer.y_hat is None else explainer.y_hat
    residual_values = (explainer.residual(explainer.data, explainer.y)
                       if explainer.residuals is None else explainer.residuals)
    y_true = explainer.y

    if self.model_type == 'regression':
        measures = {
            'mse': utils.mse(y_pred, y_true),
            'rmse': utils.rmse(y_pred, y_true),
            'r2': utils.r2(y_pred, y_true),
            'mae': utils.mae(y_pred, y_true),
            'mad': utils.mad(y_pred, y_true),
        }
    elif self.model_type == 'classification':
        above_cutoff = y_pred >= self.cutoff
        below_cutoff = y_pred < self.cutoff
        is_positive = y_true == 1
        is_negative = y_true == 0
        counts = (
            (is_positive * above_cutoff).sum(),  # true positives
            (is_negative * above_cutoff).sum(),  # false positives
            (is_negative * below_cutoff).sum(),  # true negatives
            (is_positive * below_cutoff).sum(),  # false negatives
        )
        measures = {
            'recall': utils.recall(*counts),
            'precision': utils.precision(*counts),
            'f1': utils.f1(*counts),
            'accuracy': utils.accuracy(*counts),
            'auc': utils.auc(y_pred, y_true),
        }
    else:
        raise ValueError("'model_type' must be 'regression' or 'classification'")

    self.result = pd.DataFrame({name: [value] for name, value in measures.items()},
                               index=[explainer.label])

    self.residuals = pd.DataFrame({
        'y_hat': y_pred,
        'y': y_true,
        'residuals': residual_values,
        'label': explainer.label,
    })
def plot(self, objects=None, geom='ecdf', title=None, show=False)

Plot the Model Performance explanation

Parameters

objects : ModelPerformance object or array_like of ModelPerformance objects
Additional objects to plot (default is None).
geom : {'ecdf', 'roc', 'lift'}
Type of plot determines how residuals shall be summarized.
title : str, optional
Title of the plot (default is None, which is passed unchanged to the plotting function selected by geom).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is False).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
Expand source code Browse git
def plot(self,
         objects=None,
         geom="ecdf",
         title=None,
         show=False):
    """Plot the Model Performance explanation

    Parameters
    -----------
    objects : ModelPerformance object or array_like of ModelPerformance objects
        Additional objects to plot (default is `None`).
    geom : {'ecdf', 'roc', 'lift'}
        Type of plot determines how residuals shall be summarized
        (default is `'ecdf'`).
    title : str, optional
        Title of the plot (default is `None`; the value is passed unchanged
        to the plotting helper selected by `geom`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can
        be edited or saved using the `write_image()` method (default is `False`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.

    Raises
    -----------
    TypeError
        If `geom` is not one of 'ecdf', 'roc', 'lift'.
    """

    # validate once, before any data is copied
    if geom not in ("ecdf", "roc", "lift"):
        raise TypeError("geom should be one of {'ecdf', 'roc', 'lift'}")

    # are there any other objects to plot?
    if objects is None:
        _df_list = [self.residuals.copy()]
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _df_list = [self.residuals.copy(), objects.residuals.copy()]
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        _df_list = [self.residuals.copy()]
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _df_list += [ob.residuals.copy()]
    else:
        # NOTE(review): assumed to raise; otherwise `_df_list` would be unbound below
        _global_checks.global_raise_objects_class(objects, self.__class__)

    colors = _theme.get_default_colors(len(_df_list), 'line')

    # geom was validated above, so the last branch can only be 'lift'
    if geom == 'ecdf':
        fig = plot.plot_ecdf(_df_list, colors, title)
    elif geom == 'roc':
        fig = plot.plot_roc(_df_list, colors, title)
    else:
        fig = plot.plot_lift(_df_list, colors, title)

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class ResidualDiagnostics (variables=None)

Calculate model-level residuals diagnostics

Parameters

variables : str or array_like of str, optional
Variables from data to include in the result (default is None, which means all of the variables).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
variables : array_like of str or None
Variables from data that are included in the result.

Notes

Expand source code Browse git
class ResidualDiagnostics(Explanation):
    """Calculate model-level residuals diagnostics

    Parameters
    -----------
    variables : str or array_like of str, optional
        Variables from `data` to include in the `result` attribute
        (default is `None`, which means all of the variables).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    variables : array_like of str or None
        Variables from `data` that are included in the `result` attribute.

    Notes
    --------
    - https://pbiecek.github.io/ema/residualDiagnostic.html
    """
    def __init__(self,
                 variables=None):

        _variables = checks.check_variables(variables)

        self.result = pd.DataFrame()
        self.variables = _variables

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self, explainer):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.
        As a side effect, missing `y_hat` / `residuals` are computed and stored
        on the `explainer` for later reuse.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.

        Returns
        -----------
        None
        """
        result = explainer.data.copy()

        # if variables = NULL then all variables are added
        # otherwise only selected
        if self.variables is not None:
            result = result.loc[:, _global_utils.intersect_unsorted(self.variables, result.columns)]
        # is there target
        if explainer.y is not None:
            result = result.assign(y=explainer.y)
        # are there predictions - add y_hat to the Explainer for the future
        if explainer.y_hat is None:
            explainer.y_hat = explainer.predict(explainer.data)
        # are there residuals - add residuals to the Explainer for the future
        if explainer.residuals is None:
            explainer.residuals = explainer.residual(explainer.data, explainer.y)

        self.result = result.assign(
            y_hat=explainer.y_hat,
            residuals=explainer.residuals,
            abs_residuals=np.abs(explainer.residuals),
            label=explainer.label,
            ids=np.arange(result.shape[0])+1
        )

    def plot(self,
             objects=None,
             variable="y_hat",
             yvariable="residuals",
             smooth=True,
             line_width=2,
             marker_size=3,
             title="Residual Diagnostics",
             N=50000,
             show=True):
        """Plot the Residual Diagnostics explanation

        Parameters
        ----------
        objects : ResidualDiagnostics object or array_like of ResidualDiagnostics objects
            Additional objects to plot (default is `None`).
        variable : str, optional
            Name of the variable from the `result` attribute to appear on the OX axis
            (default is `'y_hat'`).
        yvariable : str, optional
            Name of the variable from the `result` attribute to appear on the OY axis
            (default is `'residuals'`).
        smooth : bool, optional
            Add the smooth (lowess) line; this requires `statsmodels`
            (default is `True`).
        line_width : float, optional
            Width of lines in px (default is `2`).
        marker_size : float, optional
            Size of points (default is `3`).
        title : str, optional
            Title of the plot (default is `'Residual Diagnostics'`).
        N : int, optional
            Number of observations that will be sampled from the `result` attribute before
            calculating the smooth line. This is for performance issues with large data.
            Sampling is applied only when `smooth` is `True`.
            `None` means take all `result` (default is `50_000`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """

        # statsmodels is only needed for the lowess trendline, so do not
        # demand it when no smooth line is requested
        if smooth:
            _global_checks.global_check_import('statsmodels', 'smoothing line')

        # are there any other objects to plot?
        if objects is None:
            _df_list = [self.result.copy()]
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _df_list = [self.result.copy(), objects.result.copy()]
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            _df_list = [self.result.copy()]
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _df_list += [ob.result.copy()]
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        _df = pd.concat(_df_list)

        # subsample only for the (expensive) lowess fit; fixed seed keeps plots reproducible
        if isinstance(N, int) and smooth:
            if N < _df.shape[0]:
                _df = _df.sample(N, random_state=0, replace=False)

        fig = px.scatter(_df,
                         x=variable,
                         y=yvariable,
                         hover_name='ids',
                         color="label",
                         trendline="lowess" if smooth else None,
                         color_discrete_sequence=_theme.get_default_colors(len(_df_list), 'line')) \
                .update_traces(dict(marker_size=marker_size, line_width=line_width))

        # wait for https://github.com/plotly/plotly.py/pull/2558 to add hline to the plot

        fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside',
                          'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': yvariable})

        fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
                          'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': variable})

        fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                          margin={'t': 78, 'b': 71, 'r': 30})

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, explainer)

Calculate the result of explanation

Fit method makes calculations in place and changes the attributes.

Parameters

explainer : Explainer object
Model wrapper created using the Explainer class.

Returns

None
 
Expand source code Browse git
def fit(self, explainer):
    """Build the per-observation residual table for an explainer.

    Works in place: `self.result` receives a copy of `explainer.data`
    (optionally restricted to `self.variables`) extended with the target,
    predictions, residuals, absolute residuals, the explainer label and
    1-based observation ids. Missing predictions/residuals are computed
    and cached on the explainer.

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.

    Returns
    -----------
    None
    """
    frame = explainer.data.copy()

    # None keeps every column; otherwise keep only the requested ones
    if self.variables is not None:
        kept_columns = _global_utils.intersect_unsorted(self.variables, frame.columns)
        frame = frame.loc[:, kept_columns]

    # attach the target only when the explainer has one
    if explainer.y is not None:
        frame = frame.assign(y=explainer.y)

    # compute missing predictions/residuals once and cache them on the explainer
    if explainer.y_hat is None:
        explainer.y_hat = explainer.predict(explainer.data)
    if explainer.residuals is None:
        explainer.residuals = explainer.residual(explainer.data, explainer.y)

    residual_values = explainer.residuals
    row_count = frame.shape[0]
    self.result = frame.assign(
        y_hat=explainer.y_hat,
        residuals=residual_values,
        abs_residuals=np.abs(residual_values),
        label=explainer.label,
        ids=np.arange(1, row_count + 1),
    )
def plot(self, objects=None, variable='y_hat', yvariable='residuals', smooth=True, line_width=2, marker_size=3, title='Residual Diagnostics', N=50000, show=True)

Plot the Residual Diagnostics explanation

Parameters

objects : ResidualDiagnostics object or array_like of ResidualDiagnostics objects
Additional objects to plot (default is None).
variable : str, optional
Name of the variable from the result attribute to appear on the OX axis (default is 'y_hat').
yvariable : str, optional
Name of the variable from the result attribute to appear on the OY axis (default is 'residuals').
smooth : bool, optional
Add the smooth line (default is True).
line_width : float, optional
Width of lines in px (default is 2).
marker_size : float, optional
Size of points (default is 3).
title : str, optional
Title of the plot (default is 'Residual Diagnostics').
N : int, optional
Number of observations that will be sampled from the result attribute before calculating the smooth line. This is for performance issues with large data. Sampling is applied only when smooth is True. None means take all result (default is 50_000).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
Expand source code Browse git
def plot(self,
         objects=None,
         variable="y_hat",
         yvariable="residuals",
         smooth=True,
         line_width=2,
         marker_size=3,
         title="Residual Diagnostics",
         N=50000,
         show=True):
    """Plot the Residual Diagnostics explanation

    Parameters
    ----------
    objects : ResidualDiagnostics object or array_like of ResidualDiagnostics objects
        Additional objects to plot (default is `None`).
    variable : str, optional
        Name of the variable from the `result` attribute to appear on the OX axis
        (default is `'y_hat'`).
    yvariable : str, optional
        Name of the variable from the `result` attribute to appear on the OY axis
        (default is `'residuals'`).
    smooth : bool, optional
        Add the smooth (lowess) line; this requires `statsmodels`
        (default is `True`).
    line_width : float, optional
        Width of lines in px (default is `2`).
    marker_size : float, optional
        Size of points (default is `3`).
    title : str, optional
        Title of the plot (default is `'Residual Diagnostics'`).
    N : int, optional
        Number of observations that will be sampled from the `result` attribute before
        calculating the smooth line. This is for performance issues with large data.
        Sampling is applied only when `smooth` is `True`.
        `None` means take all `result` (default is `50_000`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """

    # statsmodels is only needed for the lowess trendline, so do not
    # demand it when no smooth line is requested
    if smooth:
        _global_checks.global_check_import('statsmodels', 'smoothing line')

    # are there any other objects to plot?
    if objects is None:
        _df_list = [self.result.copy()]
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _df_list = [self.result.copy(), objects.result.copy()]
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        _df_list = [self.result.copy()]
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _df_list += [ob.result.copy()]
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    _df = pd.concat(_df_list)

    # subsample only for the (expensive) lowess fit; fixed seed keeps plots reproducible
    if isinstance(N, int) and smooth:
        if N < _df.shape[0]:
            _df = _df.sample(N, random_state=0, replace=False)

    fig = px.scatter(_df,
                     x=variable,
                     y=yvariable,
                     hover_name='ids',
                     color="label",
                     trendline="lowess" if smooth else None,
                     color_discrete_sequence=_theme.get_default_colors(len(_df_list), 'line')) \
            .update_traces(dict(marker_size=marker_size, line_width=line_width))

    # wait for https://github.com/plotly/plotly.py/pull/2558 to add hline to the plot

    fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside',
                      'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': yvariable})

    fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
                      'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True, 'title_text': variable})

    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                      margin={'t': 78, 'b': 71, 'r': 30})

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class VariableImportance (loss_function='rmse', type='variable_importance', N=1000, B=10, variables=None, variable_groups=None, keep_raw_permutations=True, processes=1, random_state=None)

Calculate model-level variable importance

Parameters

loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
If string, then such loss function will be used to assess variable importance (default is 'rmse' or '1-auc', depending on the model_type attribute).
type : {'variable_importance', 'ratio', 'difference'}, optional
Type of transformation that will be applied to dropout loss.
N : int, optional
Number of observations that will be sampled from the data attribute before the calculation of variable importance. None means all data (default is 1000).
B : int, optional
Number of permutation rounds to perform on each variable (default is 10).
variables : array_like of str, optional
Variables for which the importance will be calculated (default is None, which means all of the variables). NOTE: Ignored if variable_groups is not None.
variable_groups : dict of lists, optional
Group the variables to calculate their joint variable importance e.g. {'X': ['x1', 'x2'], 'Y': ['y1', 'y2']} (default is None).
keep_raw_permutations : bool, optional
Save results for all permutation rounds (default is True).
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
loss_function : function
Loss function used to assess the variable importance.
type : {'variable_importance', 'ratio', 'difference'}
Type of transformation that will be applied to dropout loss.
N : int
Number of observations that will be sampled from the data attribute before the calculation of variable importance.
B : int
Number of permutation rounds to perform on each variable.
variables : array_like of str or None
Variables for which the importance will be calculated
variable_groups : dict of lists or None
Grouped variables to calculate their joint variable importance.
keep_raw_permutations : bool
Save the results for all permutation rounds.
permutation : pd.DataFrame or None
The results for all permutation rounds.
processes : int
Number of parallel processes to use in calculations. Iterated over B.
random_state : int or None
Set seed for random number generator.

Notes

Expand source code Browse git
class VariableImportance(Explanation):
    """Calculate model-level variable importance

    Parameters
    -----------
    loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
        If string, then such loss function will be used to assess variable importance
        (default is `'rmse'` or `'1-auc'`, depends on `model_type` attribute).
    type : {'variable_importance', 'ratio', 'difference'}, optional
        Type of transformation that will be applied to dropout loss.
    N : int, optional
        Number of observations that will be sampled from the `data` attribute before
        the calculation of variable importance.
        `None` means all `data` (default is `1000`).
    B : int, optional
        Number of permutation rounds to perform on each variable (default is `10`).
    variables : array_like of str, optional
        Variables for which the importance will be calculated
        (default is `None`, which means all of the variables).
        NOTE: Ignored if `variable_groups` is not None.
    variable_groups : dict of lists, optional
        Group the variables to calculate their joint variable importance
        e.g. `{'X': ['x1', 'x2'], 'Y': ['y1', 'y2']}` (default is `None`).
    keep_raw_permutations : bool, optional
        Save results for all permutation rounds (default is `True`).
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    loss_function : function
        Loss function used to assess the variable importance.
    type : {'variable_importance', 'ratio', 'difference'}
        Type of transformation that will be applied to dropout loss.
    N : int
        Number of observations that will be sampled from the `data` attribute
        before the calculation of variable importance.
    B : int
        Number of permutation rounds to perform on each variable.
    variables : array_like of str or None
        Variables for which the importance will be calculated.
    variable_groups : dict of lists or None
        Grouped variables to calculate their joint variable importance.
    keep_raw_permutations : bool
        Save the results for all permutation rounds.
    permutation : pd.DataFrame or None
        The results for all permutation rounds.
    processes : int
        Number of parallel processes to use in calculations. Iterated over `B`.
    random_state : int or None
        Set seed for random number generator.

    Notes
    --------
    - https://pbiecek.github.io/ema/featureImportance.html
    """

    def __init__(self,
                 loss_function='rmse',
                 type='variable_importance',
                 N=1000,
                 B=10,
                 variables=None,
                 variable_groups=None,
                 keep_raw_permutations=True,
                 processes=1,
                 random_state=None):

        _loss_function = checks.check_loss_function(loss_function)
        _B = checks.check_B(B)
        _type = checks.check_type(type)
        _random_state = checks.check_random_state(random_state)
        _keep_raw_permutations = checks.check_keep_raw_permutations(keep_raw_permutations, B)
        _processes = checks.check_processes(processes)

        self.loss_function = _loss_function
        self.type = _type
        self.N = N
        self.B = _B
        self.variables = variables
        self.variable_groups = variable_groups
        self.random_state = _random_state
        self.keep_raw_permutations = _keep_raw_permutations
        self.result = pd.DataFrame()
        self.permutation = None
        self.processes = _processes

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self, explainer):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.

        Returns
        -----------
        None
        """

        # if `variable_groups` are not specified, then extract from `variables`
        self.variable_groups = checks.check_variable_groups(self.variable_groups, explainer)
        self.variables = checks.check_variables(self.variables, self.variable_groups, explainer)
        self.result, self.permutation = utils.calculate_variable_importance(explainer,
                                                                            self.type,
                                                                            self.loss_function,
                                                                            self.variables,
                                                                            self.N,
                                                                            self.B,
                                                                            explainer.label,
                                                                            self.processes,
                                                                            self.keep_raw_permutations,
                                                                            self.random_state)

    def plot(self,
             objects=None,
             max_vars=10,
             digits=3,
             rounding_function=np.around,
             bar_width=16,
             split=("model", "variable"),
             title="Variable Importance",
             vertical_spacing=None,
             show=True):
        """Plot the Variable Importance explanation

        Parameters
        -----------
        objects : VariableImportance object or array_like of VariableImportance objects
            Additional objects to plot in subplots (default is `None`).
        max_vars : int, optional
            Maximum number of variables that will be presented for each subplot
            (default is `10`).
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        bar_width : float, optional
            Width of bars in px (default is `16`).
        split : {'model', 'variable'}, optional
            Split the subplots by model or variable (default is `'model'`).
            With a single explanation, `'variable'` falls back to `'model'`.
        title : str, optional
            Title of the plot (default is `"Variable Importance"`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots (default is `0.2/number of rows`,
            where rows are models for `split='model'` and variables for `split='variable'`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.

        Raises
        -----------
        TypeError
            If `split` is neither `'model'` nor `'variable'`.
        ValueError
            If the plotted explanations do not have unique labels.
        """

        if isinstance(split, tuple):
            split = split[0]

        if split not in ("model", "variable"):
            raise TypeError("split should be 'model' or 'variable'")

        # collect results of all explanations to plot; `n` is the number of models
        if objects is None:
            n = 1
            _result_df = self.result.copy()
            if split == 'variable':  # force split by model if only one explainer
                split = 'model'
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            n = 2
            _result_df = pd.concat([self.result.copy(), objects.result.copy()])
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            n = len(objects) + 1
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
            # concatenate once instead of re-growing the frame on every iteration
            _result_df = pd.concat([self.result.copy()] + [ob.result.copy() for ob in objects])
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        # x-axis range: all dropout losses except the baseline, plus a 15% margin;
        # np.ptp is used because the ndarray.ptp method was removed in NumPy 2.0
        dl = _result_df.loc[_result_df.variable != '_baseline_', 'dropout_loss'].to_numpy()
        min_max_margin = np.ptp(dl) * 0.15
        min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

        # take out full model
        best_fits = _result_df[_result_df.variable == '_full_model_']

        # attach each model's full-model loss as a `full_model` column
        # (the merge produces dropout_loss_x and dropout_loss_y columns)
        _result_df = _result_df.merge(best_fits[['label', 'dropout_loss']], how="left", on="label")
        _result_df = _result_df[['label', 'variable', 'dropout_loss_x', 'dropout_loss_y']].rename(
            columns={'dropout_loss_x': 'dropout_loss', 'dropout_loss_y': 'full_model'})

        # remove full_model and baseline
        _result_df = _result_df[(_result_df.variable != '_full_model_') & (_result_df.variable != '_baseline_')]

        # variables sorted by mean dropout loss across models (descending); this orders
        # the bars (split='model') or the subplots (split='variable')
        perm = _result_df[['variable', 'dropout_loss']].groupby('variable').mean().reset_index(). \
            sort_values('dropout_loss', ascending=False).variable.values

        plot_height = 78 + 71

        colors = _theme.get_default_colors(n, 'bar')

        model_names = _result_df['label'].unique().tolist()

        if len(model_names) != n:
            raise ValueError('label must be unique for each model')

        if split == "model":
            # one subplot per model
            n_rows = n
            if vertical_spacing is None:
                vertical_spacing = 0.2 / n_rows

            fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing,
                                x_title='drop-out loss',
                                subplot_titles=model_names)

            # split df by model
            df_list = [v for k, v in _result_df.groupby('label', sort=False)]

            for i, df in enumerate(df_list):
                m = df.shape[0]
                if max_vars is not None and max_vars < m:
                    m = max_vars

                # keep the m variables with the largest dropout loss, ordered by `perm`
                df = df.sort_values('dropout_loss').tail(m) \
                    .set_index('variable').reindex(perm).dropna().reset_index()

                # bars start at this model's full-model loss
                baseline = df.iloc[0, df.columns.get_loc('full_model')]

                df = df.assign(difference=df['dropout_loss'] - baseline)
                df = self._add_label_columns(df, rounding_function, digits)

                fig.add_shape(type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper", xref="x",
                              line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}, row=i + 1, col=1)

                fig.add_bar(
                    orientation="h",
                    y=df['variable'].tolist(),
                    x=df['difference'].tolist(),
                    textposition="outside",
                    text=df['label_text'].tolist(),
                    marker_color=colors[i],
                    base=baseline,
                    hovertext=df['tooltip_text'].tolist(),
                    hoverinfo='text',
                    hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                    showlegend=False,
                    row=i + 1, col=1
                )

                fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True,
                                  'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True},
                                 row=i + 1, col=1)

                fig.update_xaxes(
                    {'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
                     'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                    row=i + 1, col=1)

                plot_height += m * bar_width + (m + 1) * bar_width / 4 + 30
        else:
            # split df by variable
            df_list = [v for k, v in _result_df.groupby('variable', sort=False)]

            # one subplot per variable, at most `max_vars` of them
            n_rows = len(df_list)
            if max_vars is not None and max_vars < n_rows:
                n_rows = max_vars

            # the default is based on the number of variable rows; a value derived from
            # the number of models can exceed plotly's 1 / (rows - 1) limit
            if vertical_spacing is None:
                vertical_spacing = 0.2 / n_rows

            variable_names = perm[0:n_rows]
            fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing,
                                x_title='drop-out loss',
                                subplot_titles=variable_names)

            df_dict = {e.variable.array[0]: e for e in df_list}

            # plot only the n_rows most important variables (in `perm` order)
            for i, variable in enumerate(perm[:n_rows]):
                df = df_dict[variable]
                m = df.shape[0]

                baseline = 0

                df = df.assign(difference=df['dropout_loss'] - df['full_model'])
                df = self._add_label_columns(df, rounding_function, digits)

                fig.add_shape(type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper", xref="x",
                              line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}, row=i + 1, col=1)

                # bars show each model's raw dropout loss; no text labels are drawn here
                fig.add_bar(
                    orientation="h",
                    y=df['label'].tolist(),
                    x=df['dropout_loss'].tolist(),
                    marker_color=colors,
                    base=baseline,
                    hovertext=df['tooltip_text'].tolist(),
                    hoverinfo='text',
                    hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                    showlegend=False,
                    row=i + 1, col=1)

                fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True,
                                  'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'dtick': 1,
                                  'fixedrange': True},
                                 row=i + 1, col=1)

                fig.update_xaxes(
                    {'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
                     'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                    row=i + 1, col=1)

                plot_height += m * bar_width + (m + 1) * bar_width / 4

        plot_height += (n_rows - 1) * 70

        fig.update_xaxes({'range': min_max})
        fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                          height=plot_height, margin={'t': 78, 'b': 71, 'r': 30})

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

    @staticmethod
    def _add_label_columns(df, rounding_function, digits):
        """Return a new frame with `label_text` and `tooltip_text` columns added.

        `label_text` is `rounding_function(abs(difference), digits)` as a string,
        prefixed with "+" when `difference` is positive; `tooltip_text` is built
        row by row with `plot.tooltip_text`.
        """
        def _signed(val):
            text = str(rounding_function(np.abs(val), digits))
            return "+" + text if val > 0 else text

        label_text = df['difference'].apply(_signed)
        tooltip_text = df.apply(lambda row: plot.tooltip_text(row, rounding_function, digits), axis=1)
        return df.assign(label_text=label_text, tooltip_text=tooltip_text)

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Subclasses

  • dalex.aspect._model_aspect_importance.object.ModelAspectImportance

Methods

def fit(self, explainer)

Calculate the result of explanation

Fit method makes calculations in place and changes the attributes.

Parameters

explainer : Explainer object
Model wrapper created using the Explainer class.

Returns

None
 
Expand source code Browse git
def fit(self, explainer):
    """Compute the variable importance for `explainer` and store it on the object.

    The work is done in place: `variable_groups`, `variables`, `result`
    and `permutation` are overwritten.

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.

    Returns
    -----------
    None
    """
    # groups are checked first because `check_variables` receives the checked groups
    checked_groups = checks.check_variable_groups(self.variable_groups, explainer)
    self.variable_groups = checked_groups
    self.variables = checks.check_variables(self.variables, checked_groups, explainer)

    # positional order expected by `utils.calculate_variable_importance`
    call_args = (
        explainer,
        self.type,
        self.loss_function,
        self.variables,
        self.N,
        self.B,
        explainer.label,
        self.processes,
        self.keep_raw_permutations,
        self.random_state,
    )
    self.result, self.permutation = utils.calculate_variable_importance(*call_args)
def plot(self, objects=None, max_vars=10, digits=3, rounding_function=<function around>, bar_width=16, split=('model', 'variable'), title='Variable Importance', vertical_spacing=None, show=True)

Plot the Variable Importance explanation

Parameters

objects : VariableImportance object or array_like of VariableImportance objects
Additional objects to plot in subplots (default is None).
max_vars : int, optional
Maximum number of variables that will be presented for each subplot (default is 10).
digits : int, optional
Number of decimal places (np.around) to round contributions. See rounding_function parameter (default is 3).
rounding_function : function, optional
A function that will be used for rounding numbers (default is np.around).
bar_width : float, optional
Width of bars in px (default is 16).
split : {'model', 'variable'}, optional
Split the subplots by model or variable (default is 'model').
title : str, optional
Title of the plot (default is "Variable Importance").
vertical_spacing : float <0, 1>, optional
Ratio of vertical space between the plots (default is 0.2/number of rows).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
Expand source code Browse git
def plot(self,
         objects=None,
         max_vars=10,
         digits=3,
         rounding_function=np.around,
         bar_width=16,
         split=("model", "variable"),
         title="Variable Importance",
         vertical_spacing=None,
         show=True):
    """Plot the Variable Importance explanation

    Parameters
    -----------
    objects : VariableImportance object or array_like of VariableImportance objects
        Additional objects to plot in subplots (default is `None`).
    max_vars : int, optional
        Maximum number of variables that will be presented for each subplot
        (default is `10`).
    digits : int, optional
        Number of decimal places (`np.around`) to round contributions.
        See `rounding_function` parameter (default is `3`).
    rounding_function : function, optional
        A function that will be used for rounding numbers (default is `np.around`).
    bar_width : float, optional
        Width of bars in px (default is `16`).
    split : {'model', 'variable'}, optional
        Split the subplots by model or variable (default is `'model'`).
        With a single explanation, `'variable'` falls back to `'model'`.
    title : str, optional
        Title of the plot (default is `"Variable Importance"`).
    vertical_spacing : float <0, 1>, optional
        Ratio of vertical space between the plots (default is `0.2/number of rows`,
        where rows are models for `split='model'` and variables for `split='variable'`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.

    Raises
    -----------
    TypeError
        If `split` is neither `'model'` nor `'variable'`.
    ValueError
        If the plotted explanations do not have unique labels.
    """

    def _add_label_columns(frame):
        # `label_text`: rounded |difference|, "+"-prefixed when positive;
        # `tooltip_text`: built row by row with `plot.tooltip_text`
        def _signed(val):
            text = str(rounding_function(np.abs(val), digits))
            return "+" + text if val > 0 else text

        label_text = frame['difference'].apply(_signed)
        tooltip_text = frame.apply(lambda row: plot.tooltip_text(row, rounding_function, digits), axis=1)
        return frame.assign(label_text=label_text, tooltip_text=tooltip_text)

    if isinstance(split, tuple):
        split = split[0]

    if split not in ("model", "variable"):
        raise TypeError("split should be 'model' or 'variable'")

    # collect results of all explanations to plot; `n` is the number of models
    if objects is None:
        n = 1
        _result_df = self.result.copy()
        if split == 'variable':  # force split by model if only one explainer
            split = 'model'
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        n = 2
        _result_df = pd.concat([self.result.copy(), objects.result.copy()])
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        n = len(objects) + 1
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
        # concatenate once instead of re-growing the frame on every iteration
        _result_df = pd.concat([self.result.copy()] + [ob.result.copy() for ob in objects])
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    # x-axis range: all dropout losses except the baseline, plus a 15% margin;
    # np.ptp is used because the ndarray.ptp method was removed in NumPy 2.0
    dl = _result_df.loc[_result_df.variable != '_baseline_', 'dropout_loss'].to_numpy()
    min_max_margin = np.ptp(dl) * 0.15
    min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

    # take out full model
    best_fits = _result_df[_result_df.variable == '_full_model_']

    # attach each model's full-model loss as a `full_model` column
    # (the merge produces dropout_loss_x and dropout_loss_y columns)
    _result_df = _result_df.merge(best_fits[['label', 'dropout_loss']], how="left", on="label")
    _result_df = _result_df[['label', 'variable', 'dropout_loss_x', 'dropout_loss_y']].rename(
        columns={'dropout_loss_x': 'dropout_loss', 'dropout_loss_y': 'full_model'})

    # remove full_model and baseline
    _result_df = _result_df[(_result_df.variable != '_full_model_') & (_result_df.variable != '_baseline_')]

    # variables sorted by mean dropout loss across models (descending); this orders
    # the bars (split='model') or the subplots (split='variable')
    perm = _result_df[['variable', 'dropout_loss']].groupby('variable').mean().reset_index(). \
        sort_values('dropout_loss', ascending=False).variable.values

    plot_height = 78 + 71

    colors = _theme.get_default_colors(n, 'bar')

    model_names = _result_df['label'].unique().tolist()

    if len(model_names) != n:
        raise ValueError('label must be unique for each model')

    if split == "model":
        # one subplot per model
        n_rows = n
        if vertical_spacing is None:
            vertical_spacing = 0.2 / n_rows

        fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing,
                            x_title='drop-out loss',
                            subplot_titles=model_names)

        # split df by model
        df_list = [v for k, v in _result_df.groupby('label', sort=False)]

        for i, df in enumerate(df_list):
            m = df.shape[0]
            if max_vars is not None and max_vars < m:
                m = max_vars

            # keep the m variables with the largest dropout loss, ordered by `perm`
            df = df.sort_values('dropout_loss').tail(m) \
                .set_index('variable').reindex(perm).dropna().reset_index()

            # bars start at this model's full-model loss
            baseline = df.iloc[0, df.columns.get_loc('full_model')]

            df = df.assign(difference=df['dropout_loss'] - baseline)
            df = _add_label_columns(df)

            fig.add_shape(type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper", xref="x",
                          line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}, row=i + 1, col=1)

            fig.add_bar(
                orientation="h",
                y=df['variable'].tolist(),
                x=df['difference'].tolist(),
                textposition="outside",
                text=df['label_text'].tolist(),
                marker_color=colors[i],
                base=baseline,
                hovertext=df['tooltip_text'].tolist(),
                hoverinfo='text',
                hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                showlegend=False,
                row=i + 1, col=1
            )

            fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True,
                              'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True},
                             row=i + 1, col=1)

            fig.update_xaxes(
                {'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
                 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                row=i + 1, col=1)

            plot_height += m * bar_width + (m + 1) * bar_width / 4 + 30
    else:
        # split df by variable
        df_list = [v for k, v in _result_df.groupby('variable', sort=False)]

        # one subplot per variable, at most `max_vars` of them
        n_rows = len(df_list)
        if max_vars is not None and max_vars < n_rows:
            n_rows = max_vars

        # the default is based on the number of variable rows; a value derived from
        # the number of models can exceed plotly's 1 / (rows - 1) limit
        if vertical_spacing is None:
            vertical_spacing = 0.2 / n_rows

        variable_names = perm[0:n_rows]
        fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing,
                            x_title='drop-out loss',
                            subplot_titles=variable_names)

        df_dict = {e.variable.array[0]: e for e in df_list}

        # plot only the n_rows most important variables (in `perm` order)
        for i, variable in enumerate(perm[:n_rows]):
            df = df_dict[variable]
            m = df.shape[0]

            baseline = 0

            df = df.assign(difference=df['dropout_loss'] - df['full_model'])
            df = _add_label_columns(df)

            fig.add_shape(type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper", xref="x",
                          line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}, row=i + 1, col=1)

            # bars show each model's raw dropout loss; no text labels are drawn here
            fig.add_bar(
                orientation="h",
                y=df['label'].tolist(),
                x=df['dropout_loss'].tolist(),
                marker_color=colors,
                base=baseline,
                hovertext=df['tooltip_text'].tolist(),
                hoverinfo='text',
                hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                showlegend=False,
                row=i + 1, col=1)

            fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True,
                              'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'dtick': 1,
                              'fixedrange': True},
                             row=i + 1, col=1)

            fig.update_xaxes(
                {'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside",
                 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                row=i + 1, col=1)

            plot_height += m * bar_width + (m + 1) * bar_width / 4

    plot_height += (n_rows - 1) * 70

    fig.update_xaxes({'range': min_max})
    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                      height=plot_height, margin={'t': 78, 'b': 71, 'r': 30})

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig