Module dalex.predict_explanations

Expand source code Browse git
from ._break_down.object import BreakDown
from ._ceteris_paribus.object import CeterisParibus
from ._shap.object import Shap

# Public API of this subpackage: the three predict-level explanation classes.
__all__ = ["BreakDown", "CeterisParibus", "Shap"]

Classes

class BreakDown (type='break_down', order=None, interaction_preference=1, keep_distributions=False)

Calculate predict-level variable attributions as Break Down

Parameters

type : {'break_down_interactions', 'break_down'}
Type of variable attributions (default is 'break_down').
order : list of int or str, optional
Use a fixed order of variables for attribution calculation. Use integer values or string variable names (default is None which means order by importance).
interaction_preference : int, optional
Specify which interactions will be present in an explanation. The larger the integer, the more frequently interactions will be presented (default is 1).
keep_distributions : bool, optional
Save the distribution of partial predictions (default is False).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
type : {'break_down_interactions', 'break_down'}
Type of variable attributions.
order : list of int or str or None
Order of variables used in attribution calculation.
interaction_preference : int
Frequency of interaction use.
keep_distributions : bool
Save the distribution of partial predictions.
yhats_distributions : pd.DataFrame or None
The distribution of partial predictions.

Notes

- https://pbiecek.github.io/ema/breakDown.html
- https://pbiecek.github.io/ema/iBreakDown.html

Expand source code Browse git
class BreakDown(Explanation):
    """Calculate predict-level variable attributions as Break Down

    Parameters
    -----------
    type : {'break_down_interactions', 'break_down'}
        Type of variable attributions (default is `'break_down'`).
    order : list of int or str, optional
        Use a fixed order of variables for attribution calculation. Use integer values
        or string variable names (default is `None` which means order by importance).
    interaction_preference : int, optional
        Specify which interactions will be present in an explanation. The larger the
        integer, the more frequently interactions will be presented (default is `1`).
    keep_distributions : bool, optional
        Save the distribution of partial predictions (default is `False`).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    type : {'break_down_interactions', 'break_down'}
        Type of variable attributions.
    order : list of int or str or None
        Order of variables used in attribution calculation.
    interaction_preference : int
        Frequency of interaction use.
    keep_distributions : bool
        Save the distribution of partial predictions.
    yhats_distributions : pd.DataFrame or None
        The distribution of partial predictions.

    Notes
    --------
    - https://pbiecek.github.io/ema/breakDown.html
    - https://pbiecek.github.io/ema/iBreakDown.html
    """

    def __init__(self,
                 type='break_down',
                 order=None,
                 interaction_preference=1,
                 keep_distributions=False):

        # `order` goes through the project's checker before being stored
        # (presumably validation/normalisation -- see `checks.check_order`).
        _order = checks.check_order(order)

        self.type = type
        self.keep_distributions = keep_distributions
        self.order = _order
        self.interaction_preference = interaction_preference
        # Empty until `fit` is called.
        self.result = pd.DataFrame()
        self.yhats_distributions = None

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self,
            explainer,
            new_observation):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        new_observation : pd.Series or np.ndarray
            An observation for which a prediction needs to be explained.
            If more than one row is passed, a warning is issued and only the
            first row is used.

        Returns
        -----------
        None

        Raises
        -----------
        ValueError
            If `type` is not one of `'break_down_interactions'`, `'break_down'`.
        """

        # Translate the attribution type into the interaction mode passed to
        # `utils.local_interactions`; reject unknown types before doing any work.
        if self.type == 'break_down_interactions':
            _mode = '2d'
        elif self.type == 'break_down':
            _mode = '1d'
        else:
            raise ValueError("'type' must be one of {'break_down_interactions', 'break_down'}")

        _new_observation = checks.check_new_observation(new_observation, explainer)
        if _new_observation.shape[0] != 1:
            warnings.warn("You should pass only one new_observation, taken only first")
            # `iloc[[0], :]` keeps a one-row DataFrame, matching what the
            # single-row path passes on; `iloc[0, :]` would collapse it to a Series.
            _new_observation = _new_observation.iloc[[0], :]

        self.result, self.yhats_distributions = utils.local_interactions(
            explainer,
            _new_observation,
            self.interaction_preference,
            _mode,
            self.order,
            self.keep_distributions
        )

    def plot(self,
             objects=None,
             baseline=None,
             max_vars=10,
             digits=3,
             rounding_function=np.around,
             bar_width=16,
             min_max=None,
             vcolors=None,
             title="Break Down",
             vertical_spacing=None,
             show=True):
        """Plot the Break Down explanation

        Parameters
        -----------
        objects : BreakDown object or array_like of BreakDown objects
            Additional objects to plot in subplots (default is `None`).
        baseline: float, optional
            Starting x point for bars. When `None`, the first `cumulative` value
            of the first plotted result is used for every subplot (presumably
            the average prediction).
        max_vars : int, optional
            Maximum number of variables that will be presented for each subplot
            (default is `10`).
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        bar_width : float, optional
            Width of bars in px (default is `16`).
        min_max : 2-tuple of float, optional
            Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`).
        vcolors : 3-tuple of str, optional
            Color of bars (default is `["#371ea3", "#8bdcbe", "#f05a71"]`).
        title : str, optional
            Title of the plot (default is `"Break Down"`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots (default is `0.2/number of rows`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """

        # are there any other objects to plot?
        if objects is None:
            _result_list = [self.result.copy()]
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _result_list = [self.result.copy(), objects.result.copy()]
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            _result_list = [self.result.copy()]
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _result_list += [ob.result.copy()]
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        # Results with several labels are split into one frame per label.
        # Single-label results keep their order and the split frames follow
        # them. Building new lists avoids appending to `_result_list` while
        # iterating over it.
        _single_results = []
        _split_results = []
        for _result in _result_list:
            if len(_result['label'].unique()) > 1:
                _split_results += [v for _, v in _result.groupby('label', sort=False)]
            else:
                _single_results.append(_result)
        _result_list = _single_results + _split_results

        # One subplot per frame actually produced (groupby drops NaN labels,
        # so counting unique labels could overestimate this).
        n = len(_result_list)
        model_names = [result.iloc[0, result.columns.get_loc("label")] for result in _result_list]

        if vertical_spacing is None:
            vertical_spacing = 0.2 / n

        fig = make_subplots(rows=n, cols=1,
                            shared_xaxes=True, vertical_spacing=vertical_spacing,
                            x_title='contribution', subplot_titles=model_names)
        # top + bottom margins (see `margin` in `update_layout` below)
        plot_height = 78 + 71

        if vcolors is None:
            vcolors = _theme.get_break_down_colors()

        # `np.inf` rather than the `np.Inf` alias, which NumPy 2.0 removed.
        if min_max is None:
            temp_min_max = [np.inf, -np.inf]
        else:
            temp_min_max = min_max

        for i, _result in enumerate(_result_list):
            # number of bars: all rows, or max_vars + 3 when the remaining
            # variables are collapsed into an aggregated bar
            if _result.shape[0] - 2 <= max_vars:
                m = _result.shape[0]
            else:
                m = max_vars + 3

            if baseline is None:
                baseline = _result.iloc[0, _result.columns.get_loc("cumulative")]

            df = plot.prepare_data_for_break_down_plot(_result, baseline, max_vars, rounding_function, digits)

            # every bar is relative except the last one (the final prediction)
            measure = ["relative"] * m
            measure[m - 1] = "total"

            fig.add_shape(
                type='line',
                x0=baseline,
                x1=baseline,
                y0=-1,
                y1=m,
                yref="paper",
                xref="x",
                line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'},
                row=i + 1, col=1
            )

            fig.add_waterfall(
                orientation="h",
                measure=measure,
                y=df['variable'].tolist(),
                x=df['contribution'].tolist(),
                textposition="outside",
                text=df['label_text'].tolist(),
                connector={"mode": "spanning", "line": {"width": 1, "color": "#371ea3", "dash": "solid"}},
                decreasing={"marker": {"color": vcolors[-1]}},
                increasing={"marker": {"color": vcolors[1]}},
                totals={"marker": {"color": vcolors[0]}},
                base=baseline,
                hovertext=df['tooltip_text'].tolist(),
                hoverinfo='text+delta',
                hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                showlegend=False,
                row=i + 1, col=1
            )

            fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True,
                              'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True},
                             row=i + 1, col=1)

            if min_max is None:
                cum = df.cumulative.values
                # `np.ptp` instead of the `ndarray.ptp` method removed in NumPy 2.0
                min_max_margin = np.ptp(cum) * 0.15
                temp_min_max[0] = np.min([temp_min_max[0], cum.min() - min_max_margin])
                temp_min_max[1] = np.max([temp_min_max[1], cum.max() + min_max_margin])

            fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                              'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                             row=i + 1, col=1)

            plot_height += m * bar_width + (m + 1) * bar_width / 4

        # extra room between consecutive subplots
        plot_height += (n - 1) * 70

        fig.update_xaxes({'range': temp_min_max})
        fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                          height=plot_height, margin={'t': 78, 'b': 71, 'r': 30})

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, explainer, new_observation)

Calculate the result of explanation

Fit method makes calculations in place and changes the attributes.

Parameters

explainer : Explainer object
Model wrapper created using the Explainer class.
new_observation : pd.Series or np.ndarray
An observation for which a prediction needs to be explained.

Returns

None
 
Expand source code Browse git
def fit(self,
        explainer,
        new_observation):
    """Calculate the result of explanation

    Fit method makes calculations in place and changes the attributes.

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.
    new_observation : pd.Series or np.ndarray
        An observation for which a prediction needs to be explained.
        If more than one row is passed, a warning is issued and only the
        first row is used.

    Returns
    -----------
    None

    Raises
    -----------
    ValueError
        If `type` is not one of `'break_down_interactions'`, `'break_down'`.
    """

    # Translate the attribution type into the interaction mode passed to
    # `utils.local_interactions`; reject unknown types before doing any work.
    if self.type == 'break_down_interactions':
        _mode = '2d'
    elif self.type == 'break_down':
        _mode = '1d'
    else:
        raise ValueError("'type' must be one of {'break_down_interactions', 'break_down'}")

    _new_observation = checks.check_new_observation(new_observation, explainer)
    if _new_observation.shape[0] != 1:
        warnings.warn("You should pass only one new_observation, taken only first")
        # `iloc[[0], :]` keeps a one-row DataFrame, matching what the
        # single-row path passes on; `iloc[0, :]` would collapse it to a Series.
        _new_observation = _new_observation.iloc[[0], :]

    self.result, self.yhats_distributions = utils.local_interactions(
        explainer,
        _new_observation,
        self.interaction_preference,
        _mode,
        self.order,
        self.keep_distributions
    )
def plot(self, objects=None, baseline=None, max_vars=10, digits=3, rounding_function=<function around>, bar_width=16, min_max=None, vcolors=None, title='Break Down', vertical_spacing=None, show=True)

Plot the Break Down explanation

Parameters

objects : BreakDown object or array_like of BreakDown objects
Additional objects to plot in subplots (default is None).
baseline : float, optional
Starting x point for bars (default is average prediction).
max_vars : int, optional
Maximum number of variables that will be presented for each subplot (default is 10).
digits : int, optional
Number of decimal places (np.around) to round contributions. See rounding_function parameter (default is 3).
rounding_function : function, optional
A function that will be used for rounding numbers (default is np.around).
bar_width : float, optional
Width of bars in px (default is 16).
min_max : 2-tuple of float, optional
Range of OX axis (default is [min-0.15*(max-min), max+0.15*(max-min)]).
vcolors : 3-tuple of str, optional
Color of bars (default is ["#371ea3", "#8bdcbe", "#f05a71"]).
title : str, optional
Title of the plot (default is "Break Down").
vertical_spacing : float <0, 1>, optional
Ratio of vertical space between the plots (default is 0.2/number of rows).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
Expand source code Browse git
def plot(self,
         objects=None,
         baseline=None,
         max_vars=10,
         digits=3,
         rounding_function=np.around,
         bar_width=16,
         min_max=None,
         vcolors=None,
         title="Break Down",
         vertical_spacing=None,
         show=True):
    """Plot the Break Down explanation

    Parameters
    -----------
    objects : BreakDown object or array_like of BreakDown objects
        Additional objects to plot in subplots (default is `None`).
    baseline: float, optional
        Starting x point for bars. When `None`, the first `cumulative` value
        of the first plotted result is used for every subplot (presumably
        the average prediction).
    max_vars : int, optional
        Maximum number of variables that will be presented for each subplot
        (default is `10`).
    digits : int, optional
        Number of decimal places (`np.around`) to round contributions.
        See `rounding_function` parameter (default is `3`).
    rounding_function : function, optional
        A function that will be used for rounding numbers (default is `np.around`).
    bar_width : float, optional
        Width of bars in px (default is `16`).
    min_max : 2-tuple of float, optional
        Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`).
    vcolors : 3-tuple of str, optional
        Color of bars (default is `["#371ea3", "#8bdcbe", "#f05a71"]`).
    title : str, optional
        Title of the plot (default is `"Break Down"`).
    vertical_spacing : float <0, 1>, optional
        Ratio of vertical space between the plots (default is `0.2/number of rows`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """

    # are there any other objects to plot?
    if objects is None:
        _result_list = [self.result.copy()]
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _result_list = [self.result.copy(), objects.result.copy()]
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        _result_list = [self.result.copy()]
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _result_list += [ob.result.copy()]
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    # Results with several labels are split into one frame per label.
    # Single-label results keep their order and the split frames follow
    # them. Building new lists avoids appending to `_result_list` while
    # iterating over it.
    _single_results = []
    _split_results = []
    for _result in _result_list:
        if len(_result['label'].unique()) > 1:
            _split_results += [v for _, v in _result.groupby('label', sort=False)]
        else:
            _single_results.append(_result)
    _result_list = _single_results + _split_results

    # One subplot per frame actually produced (groupby drops NaN labels,
    # so counting unique labels could overestimate this).
    n = len(_result_list)
    model_names = [result.iloc[0, result.columns.get_loc("label")] for result in _result_list]

    if vertical_spacing is None:
        vertical_spacing = 0.2 / n

    fig = make_subplots(rows=n, cols=1,
                        shared_xaxes=True, vertical_spacing=vertical_spacing,
                        x_title='contribution', subplot_titles=model_names)
    # top + bottom margins (see `margin` in `update_layout` below)
    plot_height = 78 + 71

    if vcolors is None:
        vcolors = _theme.get_break_down_colors()

    # `np.inf` rather than the `np.Inf` alias, which NumPy 2.0 removed.
    if min_max is None:
        temp_min_max = [np.inf, -np.inf]
    else:
        temp_min_max = min_max

    for i, _result in enumerate(_result_list):
        # number of bars: all rows, or max_vars + 3 when the remaining
        # variables are collapsed into an aggregated bar
        if _result.shape[0] - 2 <= max_vars:
            m = _result.shape[0]
        else:
            m = max_vars + 3

        if baseline is None:
            baseline = _result.iloc[0, _result.columns.get_loc("cumulative")]

        df = plot.prepare_data_for_break_down_plot(_result, baseline, max_vars, rounding_function, digits)

        # every bar is relative except the last one (the final prediction)
        measure = ["relative"] * m
        measure[m - 1] = "total"

        fig.add_shape(
            type='line',
            x0=baseline,
            x1=baseline,
            y0=-1,
            y1=m,
            yref="paper",
            xref="x",
            line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'},
            row=i + 1, col=1
        )

        fig.add_waterfall(
            orientation="h",
            measure=measure,
            y=df['variable'].tolist(),
            x=df['contribution'].tolist(),
            textposition="outside",
            text=df['label_text'].tolist(),
            connector={"mode": "spanning", "line": {"width": 1, "color": "#371ea3", "dash": "solid"}},
            decreasing={"marker": {"color": vcolors[-1]}},
            increasing={"marker": {"color": vcolors[1]}},
            totals={"marker": {"color": vcolors[0]}},
            base=baseline,
            hovertext=df['tooltip_text'].tolist(),
            hoverinfo='text+delta',
            hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
            showlegend=False,
            row=i + 1, col=1
        )

        fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True,
                          'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True},
                         row=i + 1, col=1)

        if min_max is None:
            cum = df.cumulative.values
            # `np.ptp` instead of the `ndarray.ptp` method removed in NumPy 2.0
            min_max_margin = np.ptp(cum) * 0.15
            temp_min_max[0] = np.min([temp_min_max[0], cum.min() - min_max_margin])
            temp_min_max[1] = np.max([temp_min_max[1], cum.max() + min_max_margin])

        fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                          'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                         row=i + 1, col=1)

        plot_height += m * bar_width + (m + 1) * bar_width / 4

    # extra room between consecutive subplots
    plot_height += (n - 1) * 70

    fig.update_xaxes({'range': temp_min_max})
    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                      height=plot_height, margin={'t': 78, 'b': 71, 'r': 30})

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class CeterisParibus (variables=None, grid_points=101, variable_splits=None, variable_splits_type='uniform', variable_splits_with_obs=False, processes=1)

Calculate predict-level variable profiles as Ceteris Paribus

Parameters

variables : array_like of str, optional
Variables for which the profiles will be calculated (default is None, which means all of the variables).
grid_points : int, optional
Maximum number of points for profile calculations (default is 101). NOTE: The final number of points may be lower than grid_points, e.g. if there are not enough unique values for a given variable.
variable_splits : dict of lists, optional
Split points for variables e.g. {'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']} (default is None, which means that they will be calculated using one of variable_splits_type and the data attribute).
variable_splits_type : {'uniform', 'quantiles'}, optional
Way of calculating variable_splits. Set 'quantiles' for percentiles. (default is 'uniform', which means uniform grid of points).
variable_splits_with_obs : bool, optional
Add variable values of new_observation data to the variable_splits (default is False).
processes : int, optional
Number of parallel processes to use in calculations. Iterated over variables (default is 1, which means no parallel computation).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
new_observation : pd.DataFrame
Observations for which predictions need to be explained.
variables : array_like of str or None
Variables for which the profiles will be calculated.
grid_points : int
Maximum number of points for profile calculations.
variable_splits : dict of lists or None
Split points for variables.
variable_splits_type : {'uniform', 'quantiles'}
Way of calculating variable_splits.
variable_splits_with_obs : bool
Add variable values of new_observation data to the variable_splits.
processes : int
Number of parallel processes to use in calculations. Iterated over variables.

Notes

- https://pbiecek.github.io/ema/ceterisParibus.html

Expand source code Browse git
class CeterisParibus(Explanation):
    """Calculate predict-level variable profiles as Ceteris Paribus

    Parameters
    -----------
    variables : array_like of str, optional
        Variables for which the profiles will be calculated
        (default is `None`, which means all of the variables).
    grid_points : int, optional
        Maximum number of points for profile calculations (default is `101`).
        NOTE: The final number of points may be lower than `grid_points`,
        eg. if there is not enough unique values for a given variable.
    variable_splits : dict of lists, optional
        Split points for variables e.g. `{'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']}`
        (default is `None`, which means that they will be calculated using one of
        `variable_splits_type` and the `data` attribute).
    variable_splits_type : {'uniform', 'quantiles'}, optional
        Way of calculating `variable_splits`. Set `'quantiles'` for percentiles.
        (default is `'uniform'`, which means uniform grid of points).
    variable_splits_with_obs: bool, optional
        Add variable values of `new_observation` data to the `variable_splits`
        (default is `True`).
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `variables`
        (default is `1`, which means no parallel computation).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    new_observation : pd.DataFrame
        Observations for which predictions need to be explained.
    variables : array_like of str or None
        Variables for which the profiles will be calculated.
    grid_points : int
        Maximum number of points for profile calculations.
    variable_splits : dict of lists or None
        Split points for variables.
    variable_splits_type : {'uniform', 'quantiles'}
        Way of calculating `variable_splits`.
    variable_splits_with_obs: bool
        Add variable values of `new_observation` data to the `variable_splits`.
    processes : int
        Number of parallel processes to use in calculations. Iterated over `B`.

    Notes
    --------
    - https://pbiecek.github.io/ema/ceterisParibus.html
    """

    def __init__(self,
                 variables=None,
                 grid_points=101,
                 variable_splits=None,
                 variable_splits_type='uniform',
                 variable_splits_with_obs=False,
                 processes=1):
        """Store the Ceteris Paribus configuration; profiles are computed in `fit`.

        Note that `variable_splits_with_obs` defaults to `False` here.
        """
        # Both arguments go through the project's checkers; `processes` is checked first.
        checked_processes = checks.check_processes(processes)
        checked_splits_type = checks.check_variable_splits_type(variable_splits_type)

        # configuration supplied by the user
        self.variables = variables
        self.grid_points = grid_points
        self.variable_splits = variable_splits
        self.variable_splits_type = checked_splits_type
        self.variable_splits_with_obs = variable_splits_with_obs

        # placeholders filled in by `fit`
        self.result = pd.DataFrame()
        self.new_observation = None

        self.processes = checked_processes

    def _repr_html_(self):
        """Render the explanation in notebooks as its result table."""
        table = self.result
        return table._repr_html_()

    def fit(self,
            explainer,
            new_observation,
            y=None,
            verbose=True):
        """Compute Ceteris Paribus profiles in place.

        The method updates `variables`, `variable_splits`, `new_observation`
        and `result` on this object. It does not return anything.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        new_observation : pd.DataFrame or np.ndarray
            Observations for which predictions need to be explained.
        y : pd.Series or np.ndarray (1d), optional
            Target variable with the same length as `new_observation`.
        verbose : bool, optional
            Print tqdm progress bar (default is `True`).

        Returns
        -----------
        None
        """
        # Pick the variables to profile. Any explicit splits are passed in,
        # so they may constrain the choice.
        chosen_variables = checks.check_variables(self.variables, explainer, self.variable_splits)
        self.variables = chosen_variables
        checks.check_data(explainer.data, chosen_variables)

        observation = checks.check_new_observation(new_observation, explainer)
        self.new_observation = observation

        # NOTE(review): the computed splits overwrite the user-supplied ones,
        # so a second `fit` call starts from the first call's splits.
        splits = checks.check_variable_splits(self.variable_splits,
                                              chosen_variables,
                                              self.grid_points,
                                              explainer.data,
                                              self.variable_splits_type,
                                              self.variable_splits_with_obs,
                                              observation)
        self.variable_splits = splits

        target = checks.check_y(y)

        profiles, observation_out = utils.calculate_ceteris_paribus(
            explainer,
            observation,
            splits,
            target,
            self.processes,
            verbose
        )
        self.result = profiles
        self.new_observation = observation_out

    def plot(self,
             objects=None,
             variable_type="numerical",
             variables=None,
             size=2,
             alpha=1,
             color="_label_",
             facet_ncol=2,
             show_observations=True,
             title="Ceteris Paribus Profiles",
             y_title='prediction',
             horizontal_spacing=None,
             vertical_spacing=None,
             show=True):
        """Plot the Ceteris Paribus explanation

        Parameters
        -----------
        objects : CeterisParibus object or array_like of CeterisParibus objects
            Additional objects to plot in subplots (default is `None`).
        variable_type : {'numerical', 'categorical'}
            Plot the profiles for numerical or categorical variables 
            (default is `'numerical'`).
        variables : str or array_like of str, optional
            Variables for which the profiles will be calculated
            (default is `None`, which means all of the variables).
        size : float, optional
            Width of lines in px (default is `2`).
        alpha : float <0, 1>, optional
            Opacity of lines (default is `1`).
        color : str, optional
            Variable name used for grouping
            (default is `'_label_'`, which groups by models).
        facet_ncol : int, optional
            Number of columns on the plot grid (default is `2`).
        show_observations : bool, optional
            Show observation points (default is `True`).
        title : str, optional
            Title of the plot (default is `"Ceteris Paribus Profiles"`).
        y_title : str, optional
            Title of the y/x axis (default is `"prediction"`).
        horizontal_spacing : float <0, 1>, optional
            Ratio of horizontal space between the plots
            (default depends on `variable_type`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots (default is `0.3/number of rows`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can 
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """

        if variable_type not in ("numerical", "categorical"):
            raise TypeError("variable_type should be 'numerical' or 'categorical'")
        if isinstance(variables, str):
            variables = (variables,)

        # are there any other objects to plot?
        if objects is None:
            _result_df = self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_'])
            _include = self.variable_splits_with_obs
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _result_df = pd.concat([
                self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_']),
                objects.result.assign(_original_yhat_=lambda x: objects.new_observation.loc[x.index, '_yhat_'])])
            _include = np.all([self.variable_splits_with_obs, objects.variable_splits_with_obs])
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            _result_df = self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_'])
            _include = [self.variable_splits_with_obs]
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _result_df = pd.concat([
                    _result_df, ob.result.assign(_original_yhat_=lambda x: ob.new_observation.loc[x.index, '_yhat_'])])
                _include += [ob.variable_splits_with_obs]
            _include = np.all(_include)
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        if _include is False and show_observations:
                warnings.warn("'show_observations' parameter changed to False,"
                              "because the 'variable_splits_with_obs' attribute is False"
                              "See `variable_splits_with_obs` parameter in `predict_profile`.")
                show_observations = False

        # variables to use
        all_variables = list(_result_df['_vname_'].dropna().unique())

        if variables is not None:
            all_variables = _global_utils.intersect_unsorted(variables, all_variables)
            if len(all_variables) == 0:
                raise TypeError("variables do not overlap with " + ''.join(variables))

        # names of numeric variables
        numeric_variables = _result_df[all_variables].select_dtypes(include=np.number).columns.tolist()

        if variable_type == "numerical":
            variable_names = numeric_variables

            if len(variable_names) == 0:
                # change to categorical
                variable_type = "categorical"
                # send message
                warnings.warn("'variable_type' parameter changed to 'categorical' due to lack of numerical variables.")
                # take all
                variable_names = all_variables
            elif variables is not None and len(variable_names) != len(variables):
                raise TypeError("There are no numerical variables")
        else:
            variable_names = np.setdiff1d(all_variables, numeric_variables).tolist()

            # there are variables selected
            if variables is not None:
                # take all
                variable_names = all_variables
            elif len(variable_names) == 0:
                # there were no variables selected and there are no categorical variables
                raise TypeError("There are no non-numerical variables.")

        # prepare profiles data
        _result_df = _result_df.loc[_result_df['_vname_'].isin(variable_names), ].reset_index(drop=True)

        #  calculate y axis range to allow for fixedrange True
        dl = _result_df['_yhat_'].to_numpy()
        min_max_margin = dl.ptp() * 0.10
        min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

        # create _x_
        if len(variable_names) == 1:
            _result_df.loc[:, '_x_'] = deepcopy(_result_df.loc[:, variable_names[0]])
        else:
            for variable in variable_names:
                where_variable = _result_df['_vname_'] == variable
                _result_df.loc[where_variable, '_x_'] = deepcopy(_result_df.loc[where_variable, variable])

        # change x column to proper character values
        if variable_type == 'categorical':
            _result_df.loc[:, '_x_'] = _result_df.apply(lambda row: str(row[row['_vname_']]), axis=1)

        n = len(variable_names)
        facet_nrow = int(np.ceil(n / facet_ncol))
        if vertical_spacing is None:
            vertical_spacing = 0.3 / facet_nrow if variable_type == 'numerical' else 0.05
        if horizontal_spacing is None:
            horizontal_spacing = 0.05 if variable_type == 'numerical' else 0.1

        plot_height = 78 + 71 + facet_nrow * (280 + 60)

        _result_df = _result_df.assign(_text_=_result_df.apply(lambda obs: plot.tooltip_text(obs), axis=1))

        if variable_type == "numerical":    
            m = len(_result_df[color].dropna().unique())
            _result_df[color] = _result_df[color].astype(object)  # prevent error when using pd.StringDtype
        
            fig = px.line(_result_df,
                          x="_x_", y="_yhat_", color=color, facet_col="_vname_", line_group='_ids_',
                          category_orders={"_vname_": list(variable_names)},
                          labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},  # , color: 'group'},
                          # hover_data={'_text_': True, '_yhat_': ':.3f', '_vname_': False, '_x_': False, color: False},
                          custom_data=['_text_'],
                          facet_col_wrap=facet_ncol,
                          facet_row_spacing=vertical_spacing,
                          facet_col_spacing=horizontal_spacing,
                          template="none",
                          color_discrete_sequence=_theme.get_default_colors(m, 'line')) \
                    .update_traces(dict(line_width=size, opacity=alpha,
                                        hovertemplate="%{customdata[0]}<extra></extra>")) \
                    .update_xaxes({'matches': None, 'showticklabels': True,
                                   'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                                   'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}) \
                    .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                                   'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                                   'range': min_max})

            if show_observations:
                _points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy()

                fig_points = px.scatter(_points_df,
                                        x='_original_', y='_yhat_', facet_col='_vname_',
                                        category_orders={"_vname_": list(variable_names)},
                                        labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},
                                        custom_data=['_text_'],
                                        facet_col_wrap=facet_ncol,
                                        facet_row_spacing=vertical_spacing,
                                        facet_col_spacing=horizontal_spacing,
                                        color_discrete_sequence=["#371ea3"]) \
                               .update_traces(dict(marker_size=5*size, opacity=alpha),
                                              hovertemplate="%{customdata[0]}<extra></extra>")

                for _, value in enumerate(fig_points.data):
                    fig.add_trace(value)
                    
            fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, 'closest')

        else:
            if color=="_label_" and len(_result_df['_ids_'].unique()) > 1 and len(_result_df['_label_'].unique()) == 1:
                warnings.warn("'color' parameter changed to '_ids_' because there are multiple observations for one model.")
                color = '_ids_'
            elif color=="_label_" and len(_result_df['_ids_'].unique()) > len(_result_df['_label_'].unique()): 
                # https://github.com/plotly/plotly.py/issues/2657
                raise TypeError("Please pick one observation per label or change the `color` parameter.")

            m = len(_result_df[color].dropna().unique())
            _result_df[color] = _result_df[color].astype(object)  # prevent error when using pd.StringDtype
            
            _result_df = _result_df.assign(_diff_=lambda x: x['_yhat_'] - x['_original_yhat_'])
            fig = px.bar(_result_df,
                         x="_diff_", y="_x_", color=color, facet_col="_vname_",
                         category_orders={"_vname_": list(variable_names)},
                         labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},  # , color: 'group'},
                         # hover_data={'_yhat_': ':.3f', '_ids_': True, '_vname_': False, color: False},
                         custom_data=['_text_'],
                         base="_original_yhat_",
                         facet_col_wrap=facet_ncol,
                         facet_row_spacing=vertical_spacing,
                         facet_col_spacing=horizontal_spacing,
                         template="none",
                         color_discrete_sequence=_theme.get_default_colors(m, 'line'),
                         barmode='group',
                         orientation='h')  \
                    .update_traces(dict(opacity=alpha),
                                   hovertemplate="%{customdata[0]}<extra></extra>") \
                    .update_yaxes({'matches': None, 'showticklabels': True,
                                   'type': 'category', 'gridwidth': 2, 'automargin': True, 
                                   'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}) \
                    .update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                                   'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                                   'range': min_max})
                    
            for _, bar in enumerate(fig.data):
                fig.add_vline(x=bar.base[0], layer='below',
                              line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'})
                
            if show_observations:
                _points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy()

                fig_points = px.scatter(_points_df,
                                        x='_yhat_', y='_x_', facet_col='_vname_',
                                        category_orders={"_vname_": list(variable_names)},
                                        labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},
                                        custom_data=['_text_'],
                                        facet_col_wrap=facet_ncol,
                                        facet_row_spacing=vertical_spacing,
                                        facet_col_spacing=horizontal_spacing,
                                        color_discrete_sequence=["#371ea3"]) \
                               .update_traces(dict(marker_size=5*size, opacity=alpha),
                                              hovertemplate="%{customdata[0]}<extra></extra>")

                for _, value in enumerate(fig_points.data):
                    fig.add_trace(value)

            fig = _theme.fig_update_bar_plot(fig, title, y_title, plot_height, 'closest')
            
        fig.update_layout(hoverlabel=dict(bgcolor='rgba(0,0,0,0.8)'))
        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, explainer, new_observation, y=None, verbose=True)

Calculate the result of the explanation

Fit method makes calculations in place and changes the attributes.

Parameters

explainer : Explainer object
Model wrapper created using the Explainer class.
new_observation : pd.DataFrame or np.ndarray
Observations for which predictions need to be explained.
y : pd.Series or np.ndarray (1d), optional
Target variable with the same length as new_observation.
verbose : bool, optional
Print tqdm progress bar (default is True).

Returns

None
 
Expand source code Browse git
def fit(self,
        explainer,
        new_observation,
        y=None,
        verbose=True):
    """Validate the inputs and compute the Ceteris Paribus profiles.

    The computation happens in place: the validated `variables`,
    `new_observation` and `variable_splits` are written back to the object,
    and the computed profiles are stored in `result`.

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.
    new_observation : pd.DataFrame or np.ndarray
        Observations for which predictions need to be explained.
    y : pd.Series or np.ndarray (1d), optional
        Target variable with the same length as `new_observation`.
    verbose : bool, optional
        Print tqdm progress bar (default is `True`).

    Returns
    -----------
    None
    """

    # the checked variable list is needed by every later check, so resolve it first
    self.variables = checks.check_variables(self.variables, explainer, self.variable_splits)
    checks.check_data(explainer.data, self.variables)

    self.new_observation = checks.check_new_observation(new_observation, explainer)

    # the split grid depends on the (already validated) observation, which is
    # why this check comes after check_new_observation
    split_check_args = (self.variable_splits,
                        self.variables,
                        self.grid_points,
                        explainer.data,
                        self.variable_splits_type,
                        self.variable_splits_with_obs,
                        self.new_observation)
    self.variable_splits = checks.check_variable_splits(*split_check_args)

    checked_y = checks.check_y(y)

    # the helper returns the profiles together with an updated observation frame
    profiles, observation = utils.calculate_ceteris_paribus(
        explainer,
        self.new_observation,
        self.variable_splits,
        checked_y,
        self.processes,
        verbose
    )
    self.result = profiles
    self.new_observation = observation
def plot(self, objects=None, variable_type='numerical', variables=None, size=2, alpha=1, color='_label_', facet_ncol=2, show_observations=True, title='Ceteris Paribus Profiles', y_title='prediction', horizontal_spacing=None, vertical_spacing=None, show=True)

Plot the Ceteris Paribus explanation

Parameters

objects : CeterisParibus object or array_like of CeterisParibus objects
Additional objects to plot in subplots (default is None).
variable_type : {'numerical', 'categorical'}
Plot the profiles for numerical or categorical variables (default is 'numerical').
variables : str or array_like of str, optional
Variables for which the profiles will be plotted (default is None, which means all of the variables).
size : float, optional
Width of lines in px (default is 2).
alpha : float <0, 1>, optional
Opacity of lines (default is 1).
color : str, optional
Variable name used for grouping (default is '_label_', which groups by models).
facet_ncol : int, optional
Number of columns on the plot grid (default is 2).
show_observations : bool, optional
Show observation points (default is True).
title : str, optional
Title of the plot (default is "Ceteris Paribus Profiles").
y_title : str, optional
Title of the prediction axis: the y axis for numerical profiles, the x axis for categorical ones (default is "prediction").
horizontal_spacing : float <0, 1>, optional
Ratio of horizontal space between the plots (default depends on variable_type).
vertical_spacing : float <0, 1>, optional
Ratio of vertical space between the plots (default is 0.3/number of rows).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
Expand source code Browse git
def plot(self,
         objects=None,
         variable_type="numerical",
         variables=None,
         size=2,
         alpha=1,
         color="_label_",
         facet_ncol=2,
         show_observations=True,
         title="Ceteris Paribus Profiles",
         y_title='prediction',
         horizontal_spacing=None,
         vertical_spacing=None,
         show=True):
    """Plot the Ceteris Paribus explanation

    Numerical variables are drawn as line profiles (one facet per variable);
    categorical variables are drawn as horizontal bars starting at each
    observation's own prediction.

    Parameters
    -----------
    objects : CeterisParibus object or array_like of CeterisParibus objects
        Additional objects to plot in subplots (default is `None`).
    variable_type : {'numerical', 'categorical'}
        Plot the profiles for numerical or categorical variables 
        (default is `'numerical'`).
    variables : str or array_like of str, optional
        Variables for which the profiles will be plotted
        (default is `None`, which means all of the variables).
    size : float, optional
        Width of lines in px (default is `2`).
    alpha : float <0, 1>, optional
        Opacity of lines (default is `1`).
    color : str, optional
        Variable name used for grouping
        (default is `'_label_'`, which groups by models).
    facet_ncol : int, optional
        Number of columns on the plot grid (default is `2`).
    show_observations : bool, optional
        Show observation points (default is `True`).
    title : str, optional
        Title of the plot (default is `"Ceteris Paribus Profiles"`).
    y_title : str, optional
        Title of the prediction axis: the y axis for numerical profiles,
        the x axis for categorical ones (default is `"prediction"`).
    horizontal_spacing : float <0, 1>, optional
        Ratio of horizontal space between the plots
        (default depends on `variable_type`).
    vertical_spacing : float <0, 1>, optional
        Ratio of vertical space between the plots (default is `0.3/number of rows`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can 
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.

    Raises
    -----------
    TypeError
        If `variable_type` is invalid, if `variables` do not overlap with the
        profiled variables, if the selected variables do not match
        `variable_type`, or if (categorical plot, `color='_label_'`) there is
        more than one observation per label.
    """

    if variable_type not in ("numerical", "categorical"):
        raise TypeError("variable_type should be 'numerical' or 'categorical'")
    if isinstance(variables, str):
        variables = (variables,)

    # are there any other objects to plot?
    # `_original_yhat_` attaches each observation's own prediction to its profile
    # rows (matched by index); it is the bar base in the categorical plot
    if objects is None:
        _result_df = self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_'])
        _include = self.variable_splits_with_obs
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _result_df = pd.concat([
            self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_']),
            objects.result.assign(_original_yhat_=lambda x: objects.new_observation.loc[x.index, '_yhat_'])])
        _include = np.all([self.variable_splits_with_obs, objects.variable_splits_with_obs])
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        _result_df = self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_'])
        _include = [self.variable_splits_with_obs]
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _result_df = pd.concat([
                _result_df, ob.result.assign(_original_yhat_=lambda x: ob.new_observation.loc[x.index, '_yhat_'])])
            _include += [ob.variable_splits_with_obs]
        _include = np.all(_include)
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    # `_include` may be a numpy bool (returned by np.all), for which `is False`
    # never holds, so test its truth value instead
    if not _include and show_observations:
        warnings.warn("'show_observations' parameter changed to False, "
                      "because the 'variable_splits_with_obs' attribute is False. "
                      "See `variable_splits_with_obs` parameter in `predict_profile`.")
        show_observations = False

    # variables to use
    all_variables = list(_result_df['_vname_'].dropna().unique())

    if variables is not None:
        all_variables = _global_utils.intersect_unsorted(variables, all_variables)
        if len(all_variables) == 0:
            raise TypeError("variables do not overlap with " + ', '.join(map(str, variables)))

    # names of numeric variables
    numeric_variables = _result_df[all_variables].select_dtypes(include=np.number).columns.tolist()

    if variable_type == "numerical":
        variable_names = numeric_variables

        if len(variable_names) == 0:
            # change to categorical
            variable_type = "categorical"
            # send message
            warnings.warn("'variable_type' parameter changed to 'categorical' due to lack of numerical variables.")
            # take all
            variable_names = all_variables
        elif variables is not None and len(variable_names) != len(variables):
            raise TypeError("There are no numerical variables")
    else:
        variable_names = np.setdiff1d(all_variables, numeric_variables).tolist()

        # there are variables selected
        if variables is not None:
            # take all
            variable_names = all_variables
        elif len(variable_names) == 0:
            # there were no variables selected and there are no categorical variables
            raise TypeError("There are no non-numerical variables.")

    # prepare profiles data
    _result_df = _result_df.loc[_result_df['_vname_'].isin(variable_names), ].reset_index(drop=True)

    # calculate the prediction-axis range up front so the axes can use fixedrange=True;
    # np.ptp is used because the ndarray.ptp method was removed in NumPy 2.0
    dl = _result_df['_yhat_'].to_numpy()
    min_max_margin = np.ptp(dl) * 0.10
    min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

    # create _x_: the value of the varied variable in each profile row
    if len(variable_names) == 1:
        _result_df.loc[:, '_x_'] = deepcopy(_result_df.loc[:, variable_names[0]])
    else:
        for variable in variable_names:
            where_variable = _result_df['_vname_'] == variable
            _result_df.loc[where_variable, '_x_'] = deepcopy(_result_df.loc[where_variable, variable])

    # change x column to proper character values
    if variable_type == 'categorical':
        _result_df.loc[:, '_x_'] = _result_df.apply(lambda row: str(row[row['_vname_']]), axis=1)

    n = len(variable_names)
    facet_nrow = int(np.ceil(n / facet_ncol))
    if vertical_spacing is None:
        vertical_spacing = 0.3 / facet_nrow if variable_type == 'numerical' else 0.05
    if horizontal_spacing is None:
        horizontal_spacing = 0.05 if variable_type == 'numerical' else 0.1

    plot_height = 78 + 71 + facet_nrow * (280 + 60)

    # `plot` here is the module-level plotting helper, not this method
    _result_df = _result_df.assign(_text_=_result_df.apply(plot.tooltip_text, axis=1))

    if variable_type == "numerical":
        m = len(_result_df[color].dropna().unique())
        _result_df[color] = _result_df[color].astype(object)  # prevent error when using pd.StringDtype

        fig = px.line(_result_df,
                      x="_x_", y="_yhat_", color=color, facet_col="_vname_", line_group='_ids_',
                      category_orders={"_vname_": list(variable_names)},
                      labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},
                      custom_data=['_text_'],
                      facet_col_wrap=facet_ncol,
                      facet_row_spacing=vertical_spacing,
                      facet_col_spacing=horizontal_spacing,
                      template="none",
                      color_discrete_sequence=_theme.get_default_colors(m, 'line')) \
                .update_traces(dict(line_width=size, opacity=alpha,
                                    hovertemplate="%{customdata[0]}<extra></extra>")) \
                .update_xaxes({'matches': None, 'showticklabels': True,
                               'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                               'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}) \
                .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                               'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                               'range': min_max})

        if show_observations:
            # the observation itself is the profile row whose grid value equals its original value
            _points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy()

            fig_points = px.scatter(_points_df,
                                    x='_original_', y='_yhat_', facet_col='_vname_',
                                    category_orders={"_vname_": list(variable_names)},
                                    labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},
                                    custom_data=['_text_'],
                                    facet_col_wrap=facet_ncol,
                                    facet_row_spacing=vertical_spacing,
                                    facet_col_spacing=horizontal_spacing,
                                    color_discrete_sequence=["#371ea3"]) \
                           .update_traces(dict(marker_size=5*size, opacity=alpha),
                                          hovertemplate="%{customdata[0]}<extra></extra>")

            for trace in fig_points.data:
                fig.add_trace(trace)

        fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, 'closest')

    else:
        if color == "_label_" and len(_result_df['_ids_'].unique()) > 1 and len(_result_df['_label_'].unique()) == 1:
            warnings.warn("'color' parameter changed to '_ids_' because there are multiple observations for one model.")
            color = '_ids_'
        elif color == "_label_" and len(_result_df['_ids_'].unique()) > len(_result_df['_label_'].unique()):
            # https://github.com/plotly/plotly.py/issues/2657
            raise TypeError("Please pick one observation per label or change the `color` parameter.")

        m = len(_result_df[color].dropna().unique())
        _result_df[color] = _result_df[color].astype(object)  # prevent error when using pd.StringDtype

        # bars start at the observation's own prediction and extend by the change in prediction
        _result_df = _result_df.assign(_diff_=lambda x: x['_yhat_'] - x['_original_yhat_'])
        fig = px.bar(_result_df,
                     x="_diff_", y="_x_", color=color, facet_col="_vname_",
                     category_orders={"_vname_": list(variable_names)},
                     labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},
                     custom_data=['_text_'],
                     base="_original_yhat_",
                     facet_col_wrap=facet_ncol,
                     facet_row_spacing=vertical_spacing,
                     facet_col_spacing=horizontal_spacing,
                     template="none",
                     color_discrete_sequence=_theme.get_default_colors(m, 'line'),
                     barmode='group',
                     orientation='h') \
                .update_traces(dict(opacity=alpha),
                               hovertemplate="%{customdata[0]}<extra></extra>") \
                .update_yaxes({'matches': None, 'showticklabels': True,
                               'type': 'category', 'gridwidth': 2, 'automargin': True,
                               'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}) \
                .update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                               'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                               'range': min_max})

        # mark each bar trace's base (the original prediction) with a dotted vertical line
        for bar in fig.data:
            fig.add_vline(x=bar.base[0], layer='below',
                          line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'})

        if show_observations:
            _points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy()

            fig_points = px.scatter(_points_df,
                                    x='_yhat_', y='_x_', facet_col='_vname_',
                                    category_orders={"_vname_": list(variable_names)},
                                    labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'},
                                    custom_data=['_text_'],
                                    facet_col_wrap=facet_ncol,
                                    facet_row_spacing=vertical_spacing,
                                    facet_col_spacing=horizontal_spacing,
                                    color_discrete_sequence=["#371ea3"]) \
                           .update_traces(dict(marker_size=5*size, opacity=alpha),
                                          hovertemplate="%{customdata[0]}<extra></extra>")

            for trace in fig_points.data:
                fig.add_trace(trace)

        fig = _theme.fig_update_bar_plot(fig, title, y_title, plot_height, 'closest')

    fig.update_layout(hoverlabel=dict(bgcolor='rgba(0,0,0,0.8)'))
    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class Shap (path='average', B=25, keep_distributions=False, processes=1, random_state=None)

Calculate predict-level variable attributions as Shapley Values

Parameters

path : list of int, optional
If specified, then attributions for this path will be plotted (default is 'average', which plots attribution means for B random paths).
B : int, optional
Number of random paths to calculate variable attributions (default is 25).
keep_distributions : bool, optional
Save the distribution of partial predictions (default is False).
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).

Attributes

result : pd.DataFrame
Main result attribute of an explanation.
prediction : float
Prediction for new_observation.
intercept : float
Average prediction for data.
path : list of int or 'average'
Path for which the attributions will be plotted.
B : int
Number of random paths to calculate variable attributions.
keep_distributions : bool
Save the distribution of partial predictions.
yhats_distributions : pd.DataFrame or None
The distribution of partial predictions.
processes : int
Number of parallel processes to use in calculations. Iterated over B.
random_state : int or None
Seed that was set for random number generator.

Notes

Expand source code Browse git
class Shap(Explanation):
    """Calculate predict-level variable attributions as Shapley Values

    Parameters
    -----------
    path : list of int or 'average', optional
        If specified, then attributions for this path will be plotted
        (default is `'average'`, which plots attribution means for `B` random paths).
    B : int, optional
        Number of random paths to calculate variable attributions (default is `25`).
    keep_distributions : bool, optional
        Save the distribution of partial predictions (default is `False`).
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    prediction : float
        Prediction for `new_observation`.
    intercept : float
        Average prediction for `data`.
    path : list of int or 'average'
        Path for which the attributions will be plotted.
    B : int
        Number of random paths to calculate variable attributions.
    keep_distributions : bool
        Save the distribution of partial predictions.
    yhats_distributions : pd.DataFrame or None
        The distribution of partial predictions.
    processes : int
        Number of parallel processes to use in calculations. Iterated over `B`.
    random_state : int or None
        Seed that was set for random number generator.

    Notes
    --------
    - https://pbiecek.github.io/ema/shapley.html
    """

    def __init__(self,
                 path="average",
                 B=25,
                 keep_distributions=False,
                 processes=1,
                 random_state=None):

        # validate/normalise user input via the shared checks module
        _path = checks.check_path(path)
        _processes = checks.check_processes(processes)
        _random_state = checks.check_random_state(random_state)

        self.path = _path
        self.keep_distributions = keep_distributions
        self.B = B
        # results are empty until `fit` is called
        self.result = pd.DataFrame()
        self.yhats_distributions = None
        self.prediction = None
        self.intercept = None
        self.processes = _processes
        self.random_state = _random_state

    def _repr_html_(self):
        return self.result._repr_html_()

    def fit(self,
            explainer,
            new_observation):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        new_observation : pd.Series or np.ndarray
            An observation for which a prediction needs to be explained.

        Returns
        -----------
        None
        """

        _new_observation = checks.check_new_observation(new_observation, explainer)
        checks.check_columns_in_new_observation(_new_observation, explainer)
        self.result, self.prediction, self.intercept, self.yhats_distributions = utils.shap(
            explainer,
            _new_observation,
            self.path,
            self.keep_distributions,
            self.B,
            self.processes,
            self.random_state
        )

    def plot(self,
             objects=None,
             baseline=None,
             max_vars=10,
             digits=3,
             rounding_function=np.around,
             bar_width=16,
             min_max=None,
             vcolors=None,
             title="Shapley Values",
             vertical_spacing=None,
             show=True):
        """Plot the Shapley Values explanation

        Parameters
        -----------
        objects : Shap object or array_like of Shap objects
            Additional objects to plot in subplots (default is `None`).
        baseline : float, optional
            Starting x point for bars, shared by all subplots
            (default is each object's own average prediction).
        max_vars : int, optional
            Maximum number of variables that will be presented for each subplot
            (default is `10`).
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        bar_width : float, optional
            Width of bars in px (default is `16`).
        min_max : 2-tuple of float, optional
            Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`).
        vcolors : 3-tuple of str, optional
            Color of bars (default is `["#8bdcbe", "#f05a71"]`).
        title : str, optional
            Title of the plot (default is `"Shapley Values"`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots (default is `0.2/number of rows`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can
            be edited or saved using the `write_image()` method (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """

        # are there any other objects to plot? -> gather all explanations, self first
        if objects is None:
            _objects = [self]
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _objects = [self, objects]
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
            _objects = [self, *objects]
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        n = len(_objects)
        # only rows with B == 0 are drawn (presumably the attributions selected by
        # `path`, i.e. the average or the fixed path -- confirm in utils.shap)
        _result_list = [ob.result.loc[ob.result['B'] == 0, ].copy() for ob in _objects]
        _intercept_list = [ob.intercept for ob in _objects]
        _prediction_list = [ob.prediction for ob in _objects]

        # TODO: add intercept and prediction list update for multi-class
        # deleted_indexes = []
        # for i in range(n):
        #     result = _result_list[i]
        #
        #     if len(result['label'].unique()) > 1:
        #         n += len(result['label'].unique()) - 1
        #         # add new data frames to list
        #         _result_list += [v for k, v in result.groupby('label', sort=False)]
        #         deleted_indexes += [i]
        #
        # _result_list = [j for i, j in enumerate(_result_list) if i not in deleted_indexes]
        model_names = [result.iloc[0, result.columns.get_loc("label")] for result in _result_list]

        if vertical_spacing is None:
            vertical_spacing = 0.2 / n

        fig = make_subplots(rows=n, cols=1,
                            shared_xaxes=True, vertical_spacing=vertical_spacing,
                            x_title='contribution', subplot_titles=model_names)
        plot_height = 78 + 71

        if vcolors is None:
            vcolors = _theme.get_break_down_colors()

        # np.inf (not the np.Inf alias, removed in NumPy 2.0)
        if min_max is None:
            temp_min_max = [np.inf, -np.inf]
        else:
            temp_min_max = min_max

        for i, _result in enumerate(_result_list):
            # number of bar rows: all variables, or max_vars + 1 when there are more
            # (presumably one extra row for the remaining variables)
            m = min(_result.shape[0], max_vars + 1)

            # Each subplot starts at its own intercept unless a common baseline was
            # given. Do not overwrite `baseline` itself, otherwise every later subplot
            # would reuse the first object's intercept.
            _baseline = _intercept_list[i] if baseline is None else baseline
            prediction = _prediction_list[i]

            df = plot.prepare_data_for_shap_plot(_result, _baseline, prediction, max_vars, rounding_function, digits)

            # dotted vertical line marking the baseline
            fig.add_shape(
                type='line',
                x0=_baseline,
                x1=_baseline,
                y0=-1,
                y1=m,
                yref="paper",
                xref="x",
                line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'},
                row=i + 1, col=1
            )

            fig.add_bar(
                orientation="h",
                y=df['variable'].tolist(),
                x=df['contribution'].tolist(),
                textposition="outside",
                text=df['label_text'].tolist(),
                marker_color=[vcolors[int(c)] for c in df['sign'].tolist()],
                base=_baseline,
                hovertext=df['tooltip_text'].tolist(),
                hoverinfo='text',
                hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                showlegend=False,
                row=i + 1, col=1
            )

            fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True,
                              'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True},
                             row=i + 1, col=1)

            fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                              'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                             row=i + 1, col=1)

            plot_height += m * bar_width + (m + 1) * bar_width / 4

            if min_max is None:
                # bar end points for this subplot, padded by 15% of their spread;
                # np.ptp(...) instead of the ndarray.ptp method (removed in NumPy 2.0)
                cum = df.contribution.values + _baseline
                min_max_margin = np.ptp(cum) * 0.15
                temp_min_max[0] = np.min([temp_min_max[0], cum.min() - min_max_margin])
                temp_min_max[1] = np.max([temp_min_max[1], cum.max() + min_max_margin])

        plot_height += (n - 1) * 70

        fig.update_xaxes({'range': temp_min_max})
        fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                          height=plot_height, margin={'t': 78, 'b': 71, 'r': 30})

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig

Ancestors

  • dalex._explanation.Explanation
  • abc.ABC

Methods

def fit(self, explainer, new_observation)

Calculate the result of explanation

Fit method makes calculations in place and changes the attributes.

Parameters

explainer : Explainer object
Model wrapper created using the Explainer class.
new_observation : pd.Series or np.ndarray
An observation for which a prediction needs to be explained.

Returns

None
 
Expand source code Browse git
def fit(self,
        explainer,
        new_observation):
    """Calculate the result of explanation

    Computes Shapley value attributions for a single observation and stores
    the outputs on this object; nothing is returned (fit works in place).

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.
    new_observation : pd.Series or np.ndarray
        An observation for which a prediction needs to be explained.

    Returns
    -----------
    None
    """
    # validate/convert the observation, then make sure its columns match the explainer
    observation = checks.check_new_observation(new_observation, explainer)
    checks.check_columns_in_new_observation(observation, explainer)

    shap_output = utils.shap(
        explainer,
        observation,
        self.path,
        self.keep_distributions,
        self.B,
        self.processes,
        self.random_state,
    )
    # utils.shap yields (result, prediction, intercept, yhats_distributions)
    self.result, self.prediction, self.intercept, self.yhats_distributions = shap_output
def plot(self, objects=None, baseline=None, max_vars=10, digits=3, rounding_function=<function around>, bar_width=16, min_max=None, vcolors=None, title='Shapley Values', vertical_spacing=None, show=True)

Plot the Shapley Values explanation

Parameters

objects : Shap object or array_like of Shap objects
Additional objects to plot in subplots (default is None).
baseline : float, optional
Starting x point for bars (default is average prediction).
max_vars : int, optional
Maximum number of variables that will be presented for each subplot (default is 10).
digits : int, optional
Number of decimal places (np.around) to round contributions. See rounding_function parameter (default is 3).
rounding_function : function, optional
A function that will be used for rounding numbers (default is np.around).
bar_width : float, optional
Width of bars in px (default is 16).
min_max : 2-tuple of float, optional
Range of OX axis (default is [min-0.15*(max-min), max+0.15*(max-min)]).
vcolors : 3-tuple of str, optional
Color of bars (default is ["#8bdcbe", "#f05a71"]).
title : str, optional
Title of the plot (default is "Shapley Values").
vertical_spacing : float <0, 1>, optional
Ratio of vertical space between the plots (default is 0.2/number of rows).
show : bool, optional
True shows the plot; False returns the plotly Figure object that can be edited or saved using the write_image() method (default is True).

Returns

None or plotly.graph_objects.Figure
Return figure that can be edited or saved. See show parameter.
Expand source code Browse git
def plot(self,
         objects=None,
         baseline=None,
         max_vars=10,
         digits=3,
         rounding_function=np.around,
         bar_width=16,
         min_max=None,
         vcolors=None,
         title="Shapley Values",
         vertical_spacing=None,
         show=True):
    """Plot the Shapley Values explanation

    Parameters
    -----------
    objects : Shap object or array_like of Shap objects
        Additional objects to plot in subplots (default is `None`).
    baseline : float, optional
        Starting x point for bars, shared by all subplots
        (default is each object's own average prediction).
    max_vars : int, optional
        Maximum number of variables that will be presented for each subplot
        (default is `10`).
    digits : int, optional
        Number of decimal places (`np.around`) to round contributions.
        See `rounding_function` parameter (default is `3`).
    rounding_function : function, optional
        A function that will be used for rounding numbers (default is `np.around`).
    bar_width : float, optional
        Width of bars in px (default is `16`).
    min_max : 2-tuple of float, optional
        Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`).
    vcolors : 3-tuple of str, optional
        Color of bars (default is `["#8bdcbe", "#f05a71"]`).
    title : str, optional
        Title of the plot (default is `"Shapley Values"`).
    vertical_spacing : float <0, 1>, optional
        Ratio of vertical space between the plots (default is `0.2/number of rows`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """

    # are there any other objects to plot? -> gather all explanations, self first
    if objects is None:
        _objects = [self]
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _objects = [self, objects]
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
        _objects = [self, *objects]
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    n = len(_objects)
    # only rows with B == 0 are drawn (presumably the attributions selected by
    # `path`, i.e. the average or the fixed path -- confirm in utils.shap)
    _result_list = [ob.result.loc[ob.result['B'] == 0, ].copy() for ob in _objects]
    _intercept_list = [ob.intercept for ob in _objects]
    _prediction_list = [ob.prediction for ob in _objects]

    # TODO: add intercept and prediction list update for multi-class
    # deleted_indexes = []
    # for i in range(n):
    #     result = _result_list[i]
    #
    #     if len(result['label'].unique()) > 1:
    #         n += len(result['label'].unique()) - 1
    #         # add new data frames to list
    #         _result_list += [v for k, v in result.groupby('label', sort=False)]
    #         deleted_indexes += [i]
    #
    # _result_list = [j for i, j in enumerate(_result_list) if i not in deleted_indexes]
    model_names = [result.iloc[0, result.columns.get_loc("label")] for result in _result_list]

    if vertical_spacing is None:
        vertical_spacing = 0.2 / n

    fig = make_subplots(rows=n, cols=1,
                        shared_xaxes=True, vertical_spacing=vertical_spacing,
                        x_title='contribution', subplot_titles=model_names)
    plot_height = 78 + 71

    if vcolors is None:
        vcolors = _theme.get_break_down_colors()

    # np.inf (not the np.Inf alias, removed in NumPy 2.0)
    if min_max is None:
        temp_min_max = [np.inf, -np.inf]
    else:
        temp_min_max = min_max

    for i, _result in enumerate(_result_list):
        # number of bar rows: all variables, or max_vars + 1 when there are more
        # (presumably one extra row for the remaining variables)
        m = min(_result.shape[0], max_vars + 1)

        # Each subplot starts at its own intercept unless a common baseline was
        # given. Do not overwrite `baseline` itself, otherwise every later subplot
        # would reuse the first object's intercept.
        _baseline = _intercept_list[i] if baseline is None else baseline
        prediction = _prediction_list[i]

        df = plot.prepare_data_for_shap_plot(_result, _baseline, prediction, max_vars, rounding_function, digits)

        # dotted vertical line marking the baseline
        fig.add_shape(
            type='line',
            x0=_baseline,
            x1=_baseline,
            y0=-1,
            y1=m,
            yref="paper",
            xref="x",
            line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'},
            row=i + 1, col=1
        )

        fig.add_bar(
            orientation="h",
            y=df['variable'].tolist(),
            x=df['contribution'].tolist(),
            textposition="outside",
            text=df['label_text'].tolist(),
            marker_color=[vcolors[int(c)] for c in df['sign'].tolist()],
            base=_baseline,
            hovertext=df['tooltip_text'].tolist(),
            hoverinfo='text',
            hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
            showlegend=False,
            row=i + 1, col=1
        )

        fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True,
                          'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True},
                         row=i + 1, col=1)

        fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                          'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                         row=i + 1, col=1)

        plot_height += m * bar_width + (m + 1) * bar_width / 4

        if min_max is None:
            # bar end points for this subplot, padded by 15% of their spread;
            # np.ptp(...) instead of the ndarray.ptp method (removed in NumPy 2.0)
            cum = df.contribution.values + _baseline
            min_max_margin = np.ptp(cum) * 0.15
            temp_min_max[0] = np.min([temp_min_max[0], cum.min() - min_max_margin])
            temp_min_max[1] = np.max([temp_min_max[1], cum.max() + min_max_margin])

    plot_height += (n - 1) * 70

    fig.update_xaxes({'range': temp_min_max})
    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                      height=plot_height, margin={'t': 78, 'b': 71, 'r': 30})

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig