Module dalex.predict_explanations
Expand source code Browse git
from ._break_down.object import BreakDown
from ._ceteris_paribus.object import CeterisParibus
from ._shap.object import Shap
__all__ = [
"BreakDown",
"CeterisParibus",
"Shap"
]
Classes
class BreakDown (type='break_down', order=None, interaction_preference=1, keep_distributions=False)
-
Calculate predict-level variable attributions as Break Down
Parameters
type
:{'break_down_interactions', 'break_down'}
- Type of variable attributions (default is
'break_down_interactions'
). order
:list
ofint
orstr
, optional- Use a fixed order of variables for attribution calculation. Use integer values
or string variable names (default is
None
which means order by importance). interaction_preference
:int
, optional- Specify which interactions will be present in an explanation. The larger the
integer, the more frequently interactions will be presented (default is
1
). keep_distributions
:bool
, optional- Save the distribution of partial predictions (default is
False
).
Attributes
result
:pd.DataFrame
- Main result attribute of an explanation.
type
:{'break_down_interactions', 'break_down'}
- Type of variable attributions.
order
:list
ofint
orstr
orNone
- Order of variables used in attribution calculation.
interaction_preference
:int
- Frequency of interaction use.
keep_distributions
:bool
- Save the distribution of partial predictions.
yhats_distributions
:pd.DataFrame
orNone
- The distribution of partial predictions.
Notes
Expand source code Browse git
class BreakDown(Explanation): """Calculate predict-level variable attributions as Break Down Parameters ----------- type : {'break_down_interactions', 'break_down'} Type of variable attributions (default is `'break_down_interactions'`). order : list of int or str, optional Use a fixed order of variables for attribution calculation. Use integer values or string variable names (default is `None` which means order by importance). interaction_preference : int, optional Specify which interactions will be present in an explanation. The larger the integer, the more frequently interactions will be presented (default is `1`). keep_distributions : bool, optional Save the distribution of partial predictions (default is `False`). Attributes ----------- result : pd.DataFrame Main result attribute of an explanation. type : {'break_down_interactions', 'break_down'} Type of variable attributions. order : list of int or str or None Order of variables used in attribution calculation. interaction_preference : int Frequency of interaction use. keep_distributions : bool Save the distribution of partial predictions. yhats_distributions : pd.DataFrame or None The distribution of partial predictions. Notes -------- - https://pbiecek.github.io/ema/breakDown.html - https://pbiecek.github.io/ema/iBreakDown.html """ def __init__(self, type='break_down', order=None, interaction_preference=1, keep_distributions=False): _order = checks.check_order(order) self.type = type self.keep_distributions = keep_distributions self.order = _order self.interaction_preference = interaction_preference self.result = pd.DataFrame() self.yhats_distributions = None def _repr_html_(self): return self.result._repr_html_() def fit(self, explainer, new_observation): """Calculate the result of explanation Fit method makes calculations in place and changes the attributes. Parameters ----------- explainer : Explainer object Model wrapper created using the Explainer class. new_observation : pd.Series or np.ndarray An observation for which a prediction needs to be explained. Returns ----------- None """ _new_observation = checks.check_new_observation(new_observation, explainer) if _new_observation.shape[0] != 1: warnings.warn("You should pass only one new_observation, taken only first") _new_observation = _new_observation.iloc[0, :] if self.type == 'break_down_interactions': result, yhats_distributions = utils.local_interactions( explainer, _new_observation, self.interaction_preference, '2d', self.order, self.keep_distributions ) elif self.type == 'break_down': result, yhats_distributions = utils.local_interactions( explainer, _new_observation, self.interaction_preference, '1d', self.order, self.keep_distributions ) else: raise ValueError("'type' must be one of {'break_down_interactions', 'break_down'}") self.result = result self.yhats_distributions = yhats_distributions def plot(self, objects=None, baseline=None, max_vars=10, digits=3, rounding_function=np.around, bar_width=16, min_max=None, vcolors=None, title="Break Down", vertical_spacing=None, show=True): """Plot the Break Down explanation Parameters ----------- objects : BreakDown object or array_like of BreakDown objects Additional objects to plot in subplots (default is `None`). baseline: float, optional Starting x point for bars (default is average prediction). max_vars : int, optional Maximum number of variables that will be presented for for each subplot (default is `10`). digits : int, optional Number of decimal places (`np.around`) to round contributions. See `rounding_function` parameter (default is `3`). rounding_function : function, optional A function that will be used for rounding numbers (default is `np.around`). bar_width : float, optional Width of bars in px (default is `16`). min_max : 2-tuple of float, optional Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`). vcolors : 3-tuple of str, optional Color of bars (default is `["#371ea3", "#8bdcbe", "#f05a71"]`). title : str, optional Title of the plot (default is `"Break Down"`). vertical_spacing : float <0, 1>, optional Ratio of vertical space between the plots (default is `0.2/number of rows`). show : bool, optional `True` shows the plot; `False` returns the plotly Figure object that can be edited or saved using the `write_image()` method (default is `True`). Returns ----------- None or plotly.graph_objects.Figure Return figure that can be edited or saved. See `show` parameter. """ # are there any other objects to plot? if objects is None: n = 1 _result_list = [self.result.copy()] elif isinstance(objects, self.__class__): # allow for objects to be a single element n = 2 _result_list = [self.result.copy(), objects.result.copy()] elif isinstance(objects, (list, tuple)): # objects as tuple or array n = len(objects) + 1 _result_list = [self.result.copy()] for ob in objects: _global_checks.global_check_object_class(ob, self.__class__) _result_list += [ob.result.copy()] else: _global_checks.global_raise_objects_class(objects, self.__class__) deleted_indexes = [] for i, _result in enumerate(_result_list): if len(_result['label'].unique()) > 1: n += len(_result['label'].unique()) - 1 # add new data frames to list _result_list += [v for k, v in _result.groupby('label', sort=False)] deleted_indexes += [i] _result_list = [val for i, val in enumerate(_result_list) if i not in deleted_indexes] model_names = [result.iloc[0, result.columns.get_loc("label")] for result in _result_list] if vertical_spacing is None: vertical_spacing = 0.2 / n fig = make_subplots(rows=n, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing, x_title='contribution', subplot_titles=model_names) plot_height = 78 + 71 if vcolors is None: vcolors = _theme.get_break_down_colors() if min_max is None: temp_min_max = [np.Inf, -np.Inf] else: temp_min_max = min_max for i, _result in enumerate(_result_list): if _result.shape[0] - 2 <= max_vars: m = _result.shape[0] else: m = max_vars + 3 if baseline is None: baseline = _result.iloc[0, _result.columns.get_loc("cumulative")] df = plot.prepare_data_for_break_down_plot(_result, baseline, max_vars, rounding_function, digits) measure = ["relative"] * m measure[m - 1] = "total" fig.add_shape( type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper", xref="x", line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}, row=i + 1, col=1 ) fig.add_waterfall( orientation="h", measure=measure, y=df['variable'].tolist(), x=df['contribution'].tolist(), textposition="outside", text=df['label_text'].tolist(), connector={"mode": "spanning", "line": {"width": 1, "color": "#371ea3", "dash": "solid"}}, decreasing={"marker": {"color": vcolors[-1]}}, increasing={"marker": {"color": vcolors[1]}}, totals={"marker": {"color": vcolors[0]}}, base=baseline, hovertext=df['tooltip_text'].tolist(), hoverinfo='text+delta', hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'}, showlegend=False, row=i + 1, col=1 ) fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}, row=i + 1, col=1) if min_max is None: cum = df.cumulative.values min_max_margin = cum.ptp() * 0.15 temp_min_max[0] = np.min([temp_min_max[0], cum.min() - min_max_margin]) temp_min_max[1] = np.max([temp_min_max[1], cum.max() + min_max_margin]) fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}, row=i + 1, col=1) plot_height += m * bar_width + (m + 1) * bar_width / 4 plot_height += (n - 1) * 70 fig.update_xaxes({'range': temp_min_max}) fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none", height=plot_height, margin={'t': 78, 'b': 71, 'r': 30}) if show: fig.show(config=_theme.get_default_config()) else: return fig
Ancestors
- dalex._explanation.Explanation
- abc.ABC
Methods
def fit(self, explainer, new_observation)
-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
explainer
:Explainer object
- Model wrapper created using the Explainer class.
new_observation
:pd.Series
ornp.ndarray
- An observation for which a prediction needs to be explained.
Returns
None
Expand source code Browse git
def fit(self, explainer, new_observation): """Calculate the result of explanation Fit method makes calculations in place and changes the attributes. Parameters ----------- explainer : Explainer object Model wrapper created using the Explainer class. new_observation : pd.Series or np.ndarray An observation for which a prediction needs to be explained. Returns ----------- None """ _new_observation = checks.check_new_observation(new_observation, explainer) if _new_observation.shape[0] != 1: warnings.warn("You should pass only one new_observation, taken only first") _new_observation = _new_observation.iloc[0, :] if self.type == 'break_down_interactions': result, yhats_distributions = utils.local_interactions( explainer, _new_observation, self.interaction_preference, '2d', self.order, self.keep_distributions ) elif self.type == 'break_down': result, yhats_distributions = utils.local_interactions( explainer, _new_observation, self.interaction_preference, '1d', self.order, self.keep_distributions ) else: raise ValueError("'type' must be one of {'break_down_interactions', 'break_down'}") self.result = result self.yhats_distributions = yhats_distributions
def plot(self, objects=None, baseline=None, max_vars=10, digits=3, rounding_function=<function around>, bar_width=16, min_max=None, vcolors=None, title='Break Down', vertical_spacing=None, show=True)
-
Plot the Break Down explanation
Parameters
objects
:BreakDown object
orarray_like
ofBreakDown objects
- Additional objects to plot in subplots (default is
None
). baseline
:float
, optional- Starting x point for bars (default is average prediction).
max_vars
:int
, optional- Maximum number of variables that will be presented for for each subplot
(default is
10
). digits
:int
, optional- Number of decimal places (
np.around
) to round contributions. Seerounding_function
parameter (default is3
). rounding_function
:function
, optional- A function that will be used for rounding numbers (default is
np.around
). bar_width
:float
, optional- Width of bars in px (default is
16
). min_max
:2-tuple
offloat
, optional- Range of OX axis (default is
[min-0.15*(max-min), max+0.15*(max-min)]
). vcolors
:3-tuple
ofstr
, optional- Color of bars (default is
["#371ea3", "#8bdcbe", "#f05a71"]
). title
:str
, optional- Title of the plot (default is
"Break Down"
). vertical_spacing
:float <0, 1>
, optional- Ratio of vertical space between the plots (default is
0.2/number of rows
). show
:bool
, optionalTrue
shows the plot;False
returns the plotly Figure object that can be edited or saved using thewrite_image()
method (default isTrue
).
Returns
None
orplotly.graph_objects.Figure
- Return figure that can be edited or saved. See
show
parameter.
Expand source code Browse git
def plot(self, objects=None, baseline=None, max_vars=10, digits=3, rounding_function=np.around, bar_width=16, min_max=None, vcolors=None, title="Break Down", vertical_spacing=None, show=True): """Plot the Break Down explanation Parameters ----------- objects : BreakDown object or array_like of BreakDown objects Additional objects to plot in subplots (default is `None`). baseline: float, optional Starting x point for bars (default is average prediction). max_vars : int, optional Maximum number of variables that will be presented for for each subplot (default is `10`). digits : int, optional Number of decimal places (`np.around`) to round contributions. See `rounding_function` parameter (default is `3`). rounding_function : function, optional A function that will be used for rounding numbers (default is `np.around`). bar_width : float, optional Width of bars in px (default is `16`). min_max : 2-tuple of float, optional Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`). vcolors : 3-tuple of str, optional Color of bars (default is `["#371ea3", "#8bdcbe", "#f05a71"]`). title : str, optional Title of the plot (default is `"Break Down"`). vertical_spacing : float <0, 1>, optional Ratio of vertical space between the plots (default is `0.2/number of rows`). show : bool, optional `True` shows the plot; `False` returns the plotly Figure object that can be edited or saved using the `write_image()` method (default is `True`). Returns ----------- None or plotly.graph_objects.Figure Return figure that can be edited or saved. See `show` parameter. """ # are there any other objects to plot? if objects is None: n = 1 _result_list = [self.result.copy()] elif isinstance(objects, self.__class__): # allow for objects to be a single element n = 2 _result_list = [self.result.copy(), objects.result.copy()] elif isinstance(objects, (list, tuple)): # objects as tuple or array n = len(objects) + 1 _result_list = [self.result.copy()] for ob in objects: _global_checks.global_check_object_class(ob, self.__class__) _result_list += [ob.result.copy()] else: _global_checks.global_raise_objects_class(objects, self.__class__) deleted_indexes = [] for i, _result in enumerate(_result_list): if len(_result['label'].unique()) > 1: n += len(_result['label'].unique()) - 1 # add new data frames to list _result_list += [v for k, v in _result.groupby('label', sort=False)] deleted_indexes += [i] _result_list = [val for i, val in enumerate(_result_list) if i not in deleted_indexes] model_names = [result.iloc[0, result.columns.get_loc("label")] for result in _result_list] if vertical_spacing is None: vertical_spacing = 0.2 / n fig = make_subplots(rows=n, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing, x_title='contribution', subplot_titles=model_names) plot_height = 78 + 71 if vcolors is None: vcolors = _theme.get_break_down_colors() if min_max is None: temp_min_max = [np.Inf, -np.Inf] else: temp_min_max = min_max for i, _result in enumerate(_result_list): if _result.shape[0] - 2 <= max_vars: m = _result.shape[0] else: m = max_vars + 3 if baseline is None: baseline = _result.iloc[0, _result.columns.get_loc("cumulative")] df = plot.prepare_data_for_break_down_plot(_result, baseline, max_vars, rounding_function, digits) measure = ["relative"] * m measure[m - 1] = "total" fig.add_shape( type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper", xref="x", line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}, row=i + 1, col=1 ) fig.add_waterfall( orientation="h", measure=measure, y=df['variable'].tolist(), x=df['contribution'].tolist(), textposition="outside", text=df['label_text'].tolist(), connector={"mode": "spanning", "line": {"width": 1, "color": "#371ea3", "dash": "solid"}}, decreasing={"marker": {"color": vcolors[-1]}}, increasing={"marker": {"color": vcolors[1]}}, totals={"marker": {"color": vcolors[0]}}, base=baseline, hovertext=df['tooltip_text'].tolist(), hoverinfo='text+delta', hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'}, showlegend=False, row=i + 1, col=1 ) fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}, row=i + 1, col=1) if min_max is None: cum = df.cumulative.values min_max_margin = cum.ptp() * 0.15 temp_min_max[0] = np.min([temp_min_max[0], cum.min() - min_max_margin]) temp_min_max[1] = np.max([temp_min_max[1], cum.max() + min_max_margin]) fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}, row=i + 1, col=1) plot_height += m * bar_width + (m + 1) * bar_width / 4 plot_height += (n - 1) * 70 fig.update_xaxes({'range': temp_min_max}) fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none", height=plot_height, margin={'t': 78, 'b': 71, 'r': 30}) if show: fig.show(config=_theme.get_default_config()) else: return fig
class CeterisParibus (variables=None, grid_points=101, variable_splits=None, variable_splits_type='uniform', variable_splits_with_obs=False, processes=1)
-
Calculate predict-level variable profiles as Ceteris Paribus
Parameters
variables
:array_like
ofstr
, optional- Variables for which the profiles will be calculated
(default is
None
, which means all of the variables). grid_points
:int
, optional- Maximum number of points for profile calculations (default is
101
). NOTE: The final number of points may be lower thangrid_points
, eg. if there is not enough unique values for a given variable. variable_splits
:dict
oflists
, optional- Split points for variables e.g.
{'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']}
(default isNone
, which means that they will be calculated using one ofvariable_splits_type
and thedata
attribute). variable_splits_type
:{'uniform', 'quantiles'}
, optional- Way of calculating
variable_splits
. Set'quantiles'
for percentiles. (default is'uniform'
, which means uniform grid of points). variable_splits_with_obs
:bool
, optional- Add variable values of
new_observation
data to thevariable_splits
(default isTrue
). processes
:int
, optional- Number of parallel processes to use in calculations. Iterated over
variables
(default is1
, which means no parallel computation).
Attributes
result
:pd.DataFrame
- Main result attribute of an explanation.
new_observation
:pd.DataFrame
- Observations for which predictions need to be explained.
variables
:array_like
ofstr
orNone
- Variables for which the profiles will be calculated.
grid_points
:int
- Maximum number of points for profile calculations.
variable_splits
:dict
oflists
orNone
- Split points for variables.
variable_splits_type
:{'uniform', 'quantiles'}
- Way of calculating
variable_splits
. variable_splits_with_obs
:bool
- Add variable values of
new_observation
data to thevariable_splits
. processes
:int
- Number of parallel processes to use in calculations. Iterated over
B
.
Notes
Expand source code Browse git
class CeterisParibus(Explanation): """Calculate predict-level variable profiles as Ceteris Paribus Parameters ----------- variables : array_like of str, optional Variables for which the profiles will be calculated (default is `None`, which means all of the variables). grid_points : int, optional Maximum number of points for profile calculations (default is `101`). NOTE: The final number of points may be lower than `grid_points`, eg. if there is not enough unique values for a given variable. variable_splits : dict of lists, optional Split points for variables e.g. `{'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']}` (default is `None`, which means that they will be calculated using one of `variable_splits_type` and the `data` attribute). variable_splits_type : {'uniform', 'quantiles'}, optional Way of calculating `variable_splits`. Set `'quantiles'` for percentiles. (default is `'uniform'`, which means uniform grid of points). variable_splits_with_obs: bool, optional Add variable values of `new_observation` data to the `variable_splits` (default is `True`). processes : int, optional Number of parallel processes to use in calculations. Iterated over `variables` (default is `1`, which means no parallel computation). Attributes ----------- result : pd.DataFrame Main result attribute of an explanation. new_observation : pd.DataFrame Observations for which predictions need to be explained. variables : array_like of str or None Variables for which the profiles will be calculated. grid_points : int Maximum number of points for profile calculations. variable_splits : dict of lists or None Split points for variables. variable_splits_type : {'uniform', 'quantiles'} Way of calculating `variable_splits`. variable_splits_with_obs: bool Add variable values of `new_observation` data to the `variable_splits`. processes : int Number of parallel processes to use in calculations. Iterated over `B`. Notes -------- - https://pbiecek.github.io/ema/ceterisParibus.html """ def __init__(self, variables=None, grid_points=101, variable_splits=None, variable_splits_type='uniform', variable_splits_with_obs=False, processes=1): _processes = checks.check_processes(processes) _variable_splits_type = checks.check_variable_splits_type(variable_splits_type) self.variables = variables self.grid_points = grid_points self.variable_splits = variable_splits self.variable_splits_type = _variable_splits_type self.variable_splits_with_obs = variable_splits_with_obs self.result = pd.DataFrame() self.new_observation = None self.processes = _processes def _repr_html_(self): return self.result._repr_html_() def fit(self, explainer, new_observation, y=None, verbose=True): """Calculate the result of explanation Fit method makes calculations in place and changes the attributes. Parameters ----------- explainer : Explainer object Model wrapper created using the Explainer class. new_observation : pd.DataFrame or np.ndarray Observations for which predictions need to be explained. y : pd.Series or np.ndarray (1d), optional Target variable with the same length as `new_observation`. verbose : bool, optional Print tqdm progress bar (default is `True`). Returns ----------- None """ self.variables = checks.check_variables(self.variables, explainer, self.variable_splits) checks.check_data(explainer.data, self.variables) self.new_observation = checks.check_new_observation(new_observation, explainer) self.variable_splits = checks.check_variable_splits(self.variable_splits, self.variables, self.grid_points, explainer.data, self.variable_splits_type, self.variable_splits_with_obs, self.new_observation) y = checks.check_y(y) self.result, self.new_observation = utils.calculate_ceteris_paribus( explainer, self.new_observation, self.variable_splits, y, self.processes, verbose ) def plot(self, objects=None, variable_type="numerical", variables=None, size=2, alpha=1, color="_label_", facet_ncol=2, show_observations=True, title="Ceteris Paribus Profiles", y_title='prediction', horizontal_spacing=None, vertical_spacing=None, show=True): """Plot the Ceteris Paribus explanation Parameters ----------- objects : CeterisParibus object or array_like of CeterisParibus objects Additional objects to plot in subplots (default is `None`). variable_type : {'numerical', 'categorical'} Plot the profiles for numerical or categorical variables (default is `'numerical'`). variables : str or array_like of str, optional Variables for which the profiles will be calculated (default is `None`, which means all of the variables). size : float, optional Width of lines in px (default is `2`). alpha : float <0, 1>, optional Opacity of lines (default is `1`). color : str, optional Variable name used for grouping (default is `'_label_'`, which groups by models). facet_ncol : int, optional Number of columns on the plot grid (default is `2`). show_observations : bool, optional Show observation points (default is `True`). title : str, optional Title of the plot (default is `"Ceteris Paribus Profiles"`). y_title : str, optional Title of the y/x axis (default is `"prediction"`). horizontal_spacing : float <0, 1>, optional Ratio of horizontal space between the plots (default depends on `variable_type`). vertical_spacing : float <0, 1>, optional Ratio of vertical space between the plots (default is `0.3/number of rows`). show : bool, optional `True` shows the plot; `False` returns the plotly Figure object that can be edited or saved using the `write_image()` method (default is `True`). Returns ----------- None or plotly.graph_objects.Figure Return figure that can be edited or saved. See `show` parameter. """ if variable_type not in ("numerical", "categorical"): raise TypeError("variable_type should be 'numerical' or 'categorical'") if isinstance(variables, str): variables = (variables,) # are there any other objects to plot? if objects is None: _result_df = self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_']) _include = self.variable_splits_with_obs elif isinstance(objects, self.__class__): # allow for objects to be a single element _result_df = pd.concat([ self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_']), objects.result.assign(_original_yhat_=lambda x: objects.new_observation.loc[x.index, '_yhat_'])]) _include = np.all([self.variable_splits_with_obs, objects.variable_splits_with_obs]) elif isinstance(objects, (list, tuple)): # objects as tuple or array _result_df = self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_']) _include = [self.variable_splits_with_obs] for ob in objects: _global_checks.global_check_object_class(ob, self.__class__) _result_df = pd.concat([ _result_df, ob.result.assign(_original_yhat_=lambda x: ob.new_observation.loc[x.index, '_yhat_'])]) _include += [ob.variable_splits_with_obs] _include = np.all(_include) else: _global_checks.global_raise_objects_class(objects, self.__class__) if _include is False and show_observations: warnings.warn("'show_observations' parameter changed to False," "because the 'variable_splits_with_obs' attribute is False" "See `variable_splits_with_obs` parameter in `predict_profile`.") show_observations = False # variables to use all_variables = list(_result_df['_vname_'].dropna().unique()) if variables is not None: all_variables = _global_utils.intersect_unsorted(variables, all_variables) if len(all_variables) == 0: raise TypeError("variables do not overlap with " + ''.join(variables)) # names of numeric variables numeric_variables = _result_df[all_variables].select_dtypes(include=np.number).columns.tolist() if variable_type == "numerical": variable_names = numeric_variables if len(variable_names) == 0: # change to categorical variable_type = "categorical" # send message warnings.warn("'variable_type' parameter changed to 'categorical' due to lack of numerical variables.") # take all variable_names = all_variables elif variables is not None and len(variable_names) != len(variables): raise TypeError("There are no numerical variables") else: variable_names = np.setdiff1d(all_variables, numeric_variables).tolist() # there are variables selected if variables is not None: # take all variable_names = all_variables elif len(variable_names) == 0: # there were no variables selected and there are no categorical variables raise TypeError("There are no non-numerical variables.") # prepare profiles data _result_df = _result_df.loc[_result_df['_vname_'].isin(variable_names), ].reset_index(drop=True) # calculate y axis range to allow for fixedrange True dl = _result_df['_yhat_'].to_numpy() min_max_margin = dl.ptp() * 0.10 min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin] # create _x_ if len(variable_names) == 1: _result_df.loc[:, '_x_'] = deepcopy(_result_df.loc[:, variable_names[0]]) else: for variable in variable_names: where_variable = _result_df['_vname_'] == variable _result_df.loc[where_variable, '_x_'] = deepcopy(_result_df.loc[where_variable, variable]) # change x column to proper character values if variable_type == 'categorical': _result_df.loc[:, '_x_'] = _result_df.apply(lambda row: str(row[row['_vname_']]), axis=1) n = len(variable_names) facet_nrow = int(np.ceil(n / facet_ncol)) if vertical_spacing is None: vertical_spacing = 0.3 / facet_nrow if variable_type == 'numerical' else 0.05 if horizontal_spacing is None: horizontal_spacing = 0.05 if variable_type == 'numerical' else 0.1 plot_height = 78 + 71 + facet_nrow * (280 + 60) _result_df = _result_df.assign(_text_=_result_df.apply(lambda obs: plot.tooltip_text(obs), axis=1)) if variable_type == "numerical": m = len(_result_df[color].dropna().unique()) _result_df[color] = _result_df[color].astype(object) # prevent error when using pd.StringDtype fig = px.line(_result_df, x="_x_", y="_yhat_", color=color, facet_col="_vname_", line_group='_ids_', category_orders={"_vname_": list(variable_names)}, labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'}, # , color: 'group'}, # hover_data={'_text_': True, '_yhat_': ':.3f', '_vname_': False, '_x_': False, color: False}, custom_data=['_text_'], facet_col_wrap=facet_ncol, facet_row_spacing=vertical_spacing, facet_col_spacing=horizontal_spacing, template="none", color_discrete_sequence=_theme.get_default_colors(m, 'line')) \ .update_traces(dict(line_width=size, opacity=alpha, hovertemplate="%{customdata[0]}<extra></extra>")) \ .update_xaxes({'matches': None, 'showticklabels': True, 'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}) \ .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True, 'range': min_max}) if show_observations: _points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy() fig_points = px.scatter(_points_df, x='_original_', y='_yhat_', facet_col='_vname_', category_orders={"_vname_": list(variable_names)}, labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'}, custom_data=['_text_'], facet_col_wrap=facet_ncol, facet_row_spacing=vertical_spacing, facet_col_spacing=horizontal_spacing, color_discrete_sequence=["#371ea3"]) \ .update_traces(dict(marker_size=5*size, opacity=alpha), hovertemplate="%{customdata[0]}<extra></extra>") for _, value in enumerate(fig_points.data): fig.add_trace(value) fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, 'closest') else: if color=="_label_" and len(_result_df['_ids_'].unique()) > 1 and len(_result_df['_label_'].unique()) == 1: warnings.warn("'color' parameter changed to '_ids_' because there are multiple observations for one model.") color = '_ids_' elif color=="_label_" and len(_result_df['_ids_'].unique()) > len(_result_df['_label_'].unique()): # https://github.com/plotly/plotly.py/issues/2657 raise TypeError("Please pick one observation per label or change the `color` parameter.") m = len(_result_df[color].dropna().unique()) _result_df[color] = _result_df[color].astype(object) # prevent error when using pd.StringDtype _result_df = _result_df.assign(_diff_=lambda x: x['_yhat_'] - x['_original_yhat_']) fig = px.bar(_result_df, x="_diff_", y="_x_", color=color, facet_col="_vname_", category_orders={"_vname_": list(variable_names)}, labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'}, # , color: 'group'}, # hover_data={'_yhat_': ':.3f', '_ids_': True, '_vname_': False, color: False}, custom_data=['_text_'], base="_original_yhat_", facet_col_wrap=facet_ncol, facet_row_spacing=vertical_spacing, facet_col_spacing=horizontal_spacing, template="none", color_discrete_sequence=_theme.get_default_colors(m, 'line'), barmode='group', orientation='h') \ .update_traces(dict(opacity=alpha), hovertemplate="%{customdata[0]}<extra></extra>") \ .update_yaxes({'matches': None, 'showticklabels': True, 'type': 'category', 'gridwidth': 2, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}) \ .update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True, 'range': min_max}) for _, bar in enumerate(fig.data): fig.add_vline(x=bar.base[0], layer='below', line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}) if show_observations: _points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy() fig_points = px.scatter(_points_df, x='_yhat_', y='_x_', facet_col='_vname_', category_orders={"_vname_": list(variable_names)}, labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'}, custom_data=['_text_'], facet_col_wrap=facet_ncol, facet_row_spacing=vertical_spacing, facet_col_spacing=horizontal_spacing, color_discrete_sequence=["#371ea3"]) \ .update_traces(dict(marker_size=5*size, opacity=alpha), hovertemplate="%{customdata[0]}<extra></extra>") for _, value in enumerate(fig_points.data): fig.add_trace(value) fig = _theme.fig_update_bar_plot(fig, title, y_title, plot_height, 'closest') fig.update_layout(hoverlabel=dict(bgcolor='rgba(0,0,0,0.8)')) if show: fig.show(config=_theme.get_default_config()) else: return fig
Ancestors
- dalex._explanation.Explanation
- abc.ABC
Methods
def fit(self, explainer, new_observation, y=None, verbose=True)
-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
explainer
:Explainer object
- Model wrapper created using the Explainer class.
new_observation
:pd.DataFrame
ornp.ndarray
- Observations for which predictions need to be explained.
y
:pd.Series
ornp.ndarray (1d)
, optional- Target variable with the same length as
new_observation
. verbose
:bool
, optional- Print tqdm progress bar (default is
True
).
Returns
None
Expand source code Browse git
def fit(self, explainer, new_observation, y=None, verbose=True): """Calculate the result of explanation Fit method makes calculations in place and changes the attributes. Parameters ----------- explainer : Explainer object Model wrapper created using the Explainer class. new_observation : pd.DataFrame or np.ndarray Observations for which predictions need to be explained. y : pd.Series or np.ndarray (1d), optional Target variable with the same length as `new_observation`. verbose : bool, optional Print tqdm progress bar (default is `True`). Returns ----------- None """ self.variables = checks.check_variables(self.variables, explainer, self.variable_splits) checks.check_data(explainer.data, self.variables) self.new_observation = checks.check_new_observation(new_observation, explainer) self.variable_splits = checks.check_variable_splits(self.variable_splits, self.variables, self.grid_points, explainer.data, self.variable_splits_type, self.variable_splits_with_obs, self.new_observation) y = checks.check_y(y) self.result, self.new_observation = utils.calculate_ceteris_paribus( explainer, self.new_observation, self.variable_splits, y, self.processes, verbose )
def plot(self, objects=None, variable_type='numerical', variables=None, size=2, alpha=1, color='_label_', facet_ncol=2, show_observations=True, title='Ceteris Paribus Profiles', y_title='prediction', horizontal_spacing=None, vertical_spacing=None, show=True)
-
Plot the Ceteris Paribus explanation
Parameters
objects
:CeterisParibus object
orarray_like
ofCeterisParibus objects
- Additional objects to plot in subplots (default is
None
). variable_type
:{'numerical', 'categorical'}
- Plot the profiles for numerical or categorical variables
(default is
'numerical'
). variables
:str
orarray_like
ofstr
, optional- Variables for which the profiles will be calculated
(default is
None
, which means all of the variables). size
:float
, optional- Width of lines in px (default is
2
). alpha
:float <0, 1>
, optional- Opacity of lines (default is
1
). color
:str
, optional- Variable name used for grouping
(default is
'_label_'
, which groups by models). facet_ncol
:int
, optional- Number of columns on the plot grid (default is
2
). show_observations
:bool
, optional- Show observation points (default is
True
). title
:str
, optional- Title of the plot (default is
"Ceteris Paribus Profiles"
). y_title
:str
, optional- Title of the y/x axis (default is
"prediction"
). horizontal_spacing
:float <0, 1>
, optional- Ratio of horizontal space between the plots
(default depends on
variable_type
). vertical_spacing
:float <0, 1>
, optional- Ratio of vertical space between the plots (default is
0.3/number of rows
). show
:bool
, optionalTrue
shows the plot;False
returns the plotly Figure object that can be edited or saved using thewrite_image()
method (default isTrue
).
Returns
None
orplotly.graph_objects.Figure
- Return figure that can be edited or saved. See
show
parameter.
Expand source code Browse git
def plot(self, objects=None, variable_type="numerical", variables=None, size=2, alpha=1, color="_label_", facet_ncol=2, show_observations=True, title="Ceteris Paribus Profiles", y_title='prediction', horizontal_spacing=None, vertical_spacing=None, show=True): """Plot the Ceteris Paribus explanation Parameters ----------- objects : CeterisParibus object or array_like of CeterisParibus objects Additional objects to plot in subplots (default is `None`). variable_type : {'numerical', 'categorical'} Plot the profiles for numerical or categorical variables (default is `'numerical'`). variables : str or array_like of str, optional Variables for which the profiles will be calculated (default is `None`, which means all of the variables). size : float, optional Width of lines in px (default is `2`). alpha : float <0, 1>, optional Opacity of lines (default is `1`). color : str, optional Variable name used for grouping (default is `'_label_'`, which groups by models). facet_ncol : int, optional Number of columns on the plot grid (default is `2`). show_observations : bool, optional Show observation points (default is `True`). title : str, optional Title of the plot (default is `"Ceteris Paribus Profiles"`). y_title : str, optional Title of the y/x axis (default is `"prediction"`). horizontal_spacing : float <0, 1>, optional Ratio of horizontal space between the plots (default depends on `variable_type`). vertical_spacing : float <0, 1>, optional Ratio of vertical space between the plots (default is `0.3/number of rows`). show : bool, optional `True` shows the plot; `False` returns the plotly Figure object that can be edited or saved using the `write_image()` method (default is `True`). Returns ----------- None or plotly.graph_objects.Figure Return figure that can be edited or saved. See `show` parameter. """ if variable_type not in ("numerical", "categorical"): raise TypeError("variable_type should be 'numerical' or 'categorical'") if isinstance(variables, str): variables = (variables,) # are there any other objects to plot? if objects is None: _result_df = self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_']) _include = self.variable_splits_with_obs elif isinstance(objects, self.__class__): # allow for objects to be a single element _result_df = pd.concat([ self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_']), objects.result.assign(_original_yhat_=lambda x: objects.new_observation.loc[x.index, '_yhat_'])]) _include = np.all([self.variable_splits_with_obs, objects.variable_splits_with_obs]) elif isinstance(objects, (list, tuple)): # objects as tuple or array _result_df = self.result.assign(_original_yhat_=lambda x: self.new_observation.loc[x.index, '_yhat_']) _include = [self.variable_splits_with_obs] for ob in objects: _global_checks.global_check_object_class(ob, self.__class__) _result_df = pd.concat([ _result_df, ob.result.assign(_original_yhat_=lambda x: ob.new_observation.loc[x.index, '_yhat_'])]) _include += [ob.variable_splits_with_obs] _include = np.all(_include) else: _global_checks.global_raise_objects_class(objects, self.__class__) if _include is False and show_observations: warnings.warn("'show_observations' parameter changed to False," "because the 'variable_splits_with_obs' attribute is False" "See `variable_splits_with_obs` parameter in `predict_profile`.") show_observations = False # variables to use all_variables = list(_result_df['_vname_'].dropna().unique()) if variables is not None: all_variables = _global_utils.intersect_unsorted(variables, all_variables) if len(all_variables) == 0: raise TypeError("variables do not overlap with " + ''.join(variables)) # names of numeric variables numeric_variables = _result_df[all_variables].select_dtypes(include=np.number).columns.tolist() if variable_type == "numerical": variable_names = numeric_variables if len(variable_names) == 0: # change to categorical variable_type = "categorical" # send message warnings.warn("'variable_type' parameter changed to 'categorical' due to lack of numerical variables.") # take all variable_names = all_variables elif variables is not None and len(variable_names) != len(variables): raise TypeError("There are no numerical variables") else: variable_names = np.setdiff1d(all_variables, numeric_variables).tolist() # there are variables selected if variables is not None: # take all variable_names = all_variables elif len(variable_names) == 0: # there were no variables selected and there are no categorical variables raise TypeError("There are no non-numerical variables.") # prepare profiles data _result_df = _result_df.loc[_result_df['_vname_'].isin(variable_names), ].reset_index(drop=True) # calculate y axis range to allow for fixedrange True dl = _result_df['_yhat_'].to_numpy() min_max_margin = dl.ptp() * 0.10 min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin] # create _x_ if len(variable_names) == 1: _result_df.loc[:, '_x_'] = deepcopy(_result_df.loc[:, variable_names[0]]) else: for variable in variable_names: where_variable = _result_df['_vname_'] == variable _result_df.loc[where_variable, '_x_'] = deepcopy(_result_df.loc[where_variable, variable]) # change x column to proper character values if variable_type == 'categorical': _result_df.loc[:, '_x_'] = _result_df.apply(lambda row: str(row[row['_vname_']]), axis=1) n = len(variable_names) facet_nrow = int(np.ceil(n / facet_ncol)) if vertical_spacing is None: vertical_spacing = 0.3 / facet_nrow if variable_type == 'numerical' else 0.05 if horizontal_spacing is None: horizontal_spacing = 0.05 if variable_type == 'numerical' else 0.1 plot_height = 78 + 71 + facet_nrow * (280 + 60) _result_df = _result_df.assign(_text_=_result_df.apply(lambda obs: plot.tooltip_text(obs), axis=1)) if variable_type == "numerical": m = len(_result_df[color].dropna().unique()) _result_df[color] = _result_df[color].astype(object) # prevent error when using pd.StringDtype fig = px.line(_result_df, x="_x_", y="_yhat_", color=color, facet_col="_vname_", line_group='_ids_', category_orders={"_vname_": list(variable_names)}, labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'}, # , color: 'group'}, # hover_data={'_text_': True, '_yhat_': ':.3f', '_vname_': False, '_x_': False, color: False}, custom_data=['_text_'], facet_col_wrap=facet_ncol, facet_row_spacing=vertical_spacing, facet_col_spacing=horizontal_spacing, template="none", color_discrete_sequence=_theme.get_default_colors(m, 'line')) \ .update_traces(dict(line_width=size, opacity=alpha, hovertemplate="%{customdata[0]}<extra></extra>")) \ .update_xaxes({'matches': None, 'showticklabels': True, 'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}) \ .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True, 'range': min_max}) if show_observations: _points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy() fig_points = px.scatter(_points_df, x='_original_', y='_yhat_', facet_col='_vname_', category_orders={"_vname_": list(variable_names)}, labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'}, custom_data=['_text_'], facet_col_wrap=facet_ncol, facet_row_spacing=vertical_spacing, facet_col_spacing=horizontal_spacing, color_discrete_sequence=["#371ea3"]) \ .update_traces(dict(marker_size=5*size, opacity=alpha), hovertemplate="%{customdata[0]}<extra></extra>") for _, value in enumerate(fig_points.data): fig.add_trace(value) fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, 'closest') else: if color=="_label_" and len(_result_df['_ids_'].unique()) > 1 and len(_result_df['_label_'].unique()) == 1: warnings.warn("'color' parameter changed to '_ids_' because there are multiple observations for one model.") color = '_ids_' elif color=="_label_" and len(_result_df['_ids_'].unique()) > len(_result_df['_label_'].unique()): # https://github.com/plotly/plotly.py/issues/2657 raise TypeError("Please pick one observation per label or change the `color` parameter.") m = len(_result_df[color].dropna().unique()) _result_df[color] = _result_df[color].astype(object) # prevent error when using pd.StringDtype _result_df = _result_df.assign(_diff_=lambda x: x['_yhat_'] - x['_original_yhat_']) fig = px.bar(_result_df, x="_diff_", y="_x_", color=color, facet_col="_vname_", category_orders={"_vname_": list(variable_names)}, labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'}, # , color: 'group'}, # hover_data={'_yhat_': ':.3f', '_ids_': True, '_vname_': False, color: False}, custom_data=['_text_'], base="_original_yhat_", facet_col_wrap=facet_ncol, facet_row_spacing=vertical_spacing, facet_col_spacing=horizontal_spacing, template="none", color_discrete_sequence=_theme.get_default_colors(m, 'line'), barmode='group', orientation='h') \ .update_traces(dict(opacity=alpha), hovertemplate="%{customdata[0]}<extra></extra>") \ .update_yaxes({'matches': None, 'showticklabels': True, 'type': 'category', 'gridwidth': 2, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}) \ .update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True, 'range': min_max}) for _, bar in enumerate(fig.data): fig.add_vline(x=bar.base[0], layer='below', line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}) if show_observations: _points_df = _result_df.loc[_result_df['_original_'] == _result_df['_x_'], :].copy() fig_points = px.scatter(_points_df, x='_yhat_', y='_x_', facet_col='_vname_', category_orders={"_vname_": list(variable_names)}, labels={'_yhat_': 'prediction', '_label_': 'label', '_ids_': 'id'}, custom_data=['_text_'], facet_col_wrap=facet_ncol, facet_row_spacing=vertical_spacing, facet_col_spacing=horizontal_spacing, color_discrete_sequence=["#371ea3"]) \ .update_traces(dict(marker_size=5*size, opacity=alpha), hovertemplate="%{customdata[0]}<extra></extra>") for _, value in enumerate(fig_points.data): fig.add_trace(value) fig = _theme.fig_update_bar_plot(fig, title, y_title, plot_height, 'closest') fig.update_layout(hoverlabel=dict(bgcolor='rgba(0,0,0,0.8)')) if show: fig.show(config=_theme.get_default_config()) else: return fig
class Shap (path='average', B=25, keep_distributions=False, processes=1, random_state=None)
-
Calculate predict-level variable attributions as Shapley Values
Parameters
path
:list
ofint
, optional- If specified, then attributions for this path will be plotted
(default is
'average'
, which plots attribution means forB
random paths). B
:int
, optional- Number of random paths to calculate variable attributions (default is
25
). keep_distributions
:bool
, optional- Save the distribution of partial predictions (default is
False
). processes
:int
, optional- Number of parallel processes to use in calculations. Iterated over
B
(default is1
, which means no parallel computation). random_state
:int
, optional- Set seed for random number generator (default is random seed).
Attributes
result
:pd.DataFrame
- Main result attribute of an explanation.
prediction
:float
- Prediction for
new_observation
. intercept
:float
- Average prediction for
data
. path
:list
ofint
or'average'
- Path for which the attributions will be plotted.
B
:int
- Number of random paths to calculate variable attributions.
keep_distributions
:bool
- Save the distribution of partial predictions.
yhats_distributions
:pd.DataFrame
orNone
- The distribution of partial predictions.
processes
:int
- Number of parallel processes to use in calculations. Iterated over
B
. random_state
:int
orNone
- Seed that was set for random number generator.
Notes
Expand source code Browse git
class Shap(Explanation): """Calculate predict-level variable attributions as Shapley Values Parameters ----------- path : list of int, optional If specified, then attributions for this path will be plotted (default is `'average'`, which plots attribution means for `B` random paths). B : int, optional Number of random paths to calculate variable attributions (default is `25`). keep_distributions : bool, optional Save the distribution of partial predictions (default is `False`). processes : int, optional Number of parallel processes to use in calculations. Iterated over `B` (default is `1`, which means no parallel computation). random_state : int, optional Set seed for random number generator (default is random seed). Attributes ----------- result : pd.DataFrame Main result attribute of an explanation. prediction : float Prediction for `new_observation`. intercept : float Average prediction for `data`. path : list of int or 'average' Path for which the attributions will be plotted. B : int Number of random paths to calculate variable attributions. keep_distributions : bool Save the distribution of partial predictions. yhats_distributions : pd.DataFrame or None The distribution of partial predictions. processes : int Number of parallel processes to use in calculations. Iterated over `B`. random_state : int or None Seed that was set for random number generator. Notes -------- - https://pbiecek.github.io/ema/shapley.html """ def __init__(self, path="average", B=25, keep_distributions=False, processes=1, random_state=None): _path = checks.check_path(path) _processess = checks.check_processes(processes) _random_state = checks.check_random_state(random_state) self.path = _path self.keep_distributions = keep_distributions self.B = B self.result = pd.DataFrame() self.yhats_distributions = None self.prediction = None self.intercept = None self.processes = _processess self.random_state = _random_state def _repr_html_(self): return self.result._repr_html_() def fit(self, explainer, new_observation): """Calculate the result of explanation Fit method makes calculations in place and changes the attributes. Parameters ----------- explainer : Explainer object Model wrapper created using the Explainer class. new_observation : pd.Series or np.ndarray An observation for which a prediction needs to be explained. Returns ----------- None """ _new_observation = checks.check_new_observation(new_observation, explainer) checks.check_columns_in_new_observation(_new_observation, explainer) self.result, self.prediction, self.intercept, self.yhats_distributions = utils.shap( explainer, _new_observation, self.path, self.keep_distributions, self.B, self.processes, self.random_state ) def plot(self, objects=None, baseline=None, max_vars=10, digits=3, rounding_function=np.around, bar_width=16, min_max=None, vcolors=None, title="Shapley Values", vertical_spacing=None, show=True): """Plot the Shapley Values explanation Parameters ----------- objects : Shap object or array_like of Shap objects Additional objects to plot in subplots (default is `None`). baseline: float, optional Starting x point for bars (default is average prediction). max_vars : int, optional Maximum number of variables that will be presented for for each subplot (default is `10`). digits : int, optional Number of decimal places (`np.around`) to round contributions. See `rounding_function` parameter (default is `3`). rounding_function : function, optional A function that will be used for rounding numbers (default is `np.around`). bar_width : float, optional Width of bars in px (default is `16`). min_max : 2-tuple of float, optional Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`). vcolors : 3-tuple of str, optional Color of bars (default is `["#8bdcbe", "#f05a71"]`). title : str, optional Title of the plot (default is `"Shapley Values"`). vertical_spacing : float <0, 1>, optional Ratio of vertical space between the plots (default is `0.2/number of rows`). show : bool, optional `True` shows the plot; `False` returns the plotly Figure object that can be edited or saved using the `write_image()` method (default is `True`). Returns ----------- None or plotly.graph_objects.Figure Return figure that can be edited or saved. See `show` parameter. """ # are there any other objects to plot? if objects is None: n = 1 _result_list = [self.result.loc[self.result['B'] == 0,].copy()] _intercept_list = [self.intercept] _prediction_list = [self.prediction] elif isinstance(objects, self.__class__): # allow for objects to be a single element n = 2 _result_list = [self.result.loc[self.result['B'] == 0,].copy(), objects.result.loc[objects.result['B'] == 0,].copy()] _intercept_list = [self.intercept, objects.intercept] _prediction_list = [self.prediction, objects.prediction] elif isinstance(objects, (list, tuple)): # objects as tuple or array n = len(objects) + 1 _result_list = [self.result.loc[self.result['B'] == 0,].copy()] _intercept_list = [self.intercept] _prediction_list = [self.prediction] for ob in objects: _global_checks.global_check_object_class(ob, self.__class__) _result_list += [ob.result.loc[ob.result['B'] == 0,].copy()] _intercept_list += [ob.intercept] _prediction_list += [ob.prediction] else: _global_checks.global_raise_objects_class(objects, self.__class__) # TODO: add intercept and prediction list update for multi-class # deleted_indexes = [] # for i in range(n): # result = _result_list[i] # # if len(result['label'].unique()) > 1: # n += len(result['label'].unique()) - 1 # # add new data frames to list # _result_list += [v for k, v in result.groupby('label', sort=False)] # deleted_indexes += [i] # # _result_list = [j for i, j in enumerate(_result_list) if i not in deleted_indexes] model_names = [result.iloc[0, result.columns.get_loc("label")] for result in _result_list] if vertical_spacing is None: vertical_spacing = 0.2 / n fig = make_subplots(rows=n, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing, x_title='contribution', subplot_titles=model_names) plot_height = 78 + 71 if vcolors is None: vcolors = _theme.get_break_down_colors() if min_max is None: temp_min_max = [np.Inf, -np.Inf] else: temp_min_max = min_max for i, _result in enumerate(_result_list): if _result.shape[0] <= max_vars: m = _result.shape[0] else: m = max_vars + 1 if baseline is None: baseline = _intercept_list[i] prediction = _prediction_list[i] df = plot.prepare_data_for_shap_plot(_result, baseline, prediction, max_vars, rounding_function, digits) fig.add_shape( type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper", xref="x", line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}, row=i + 1, col=1 ) fig.add_bar( orientation="h", y=df['variable'].tolist(), x=df['contribution'].tolist(), textposition="outside", text=df['label_text'].tolist(), marker_color=[vcolors[int(c)] for c in df['sign'].tolist()], base=baseline, hovertext=df['tooltip_text'].tolist(), hoverinfo='text', hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'}, showlegend=False, row=i + 1, col=1 ) fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True, 'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}, row=i + 1, col=1) fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}, row=i + 1, col=1) plot_height += m * bar_width + (m + 1) * bar_width / 4 if min_max is None: cum = df.contribution.values + baseline min_max_margin = cum.ptp() * 0.15 temp_min_max[0] = np.min([temp_min_max[0], cum.min() - min_max_margin]) temp_min_max[1] = np.max([temp_min_max[1], cum.max() + min_max_margin]) plot_height += (n - 1) * 70 fig.update_xaxes({'range': temp_min_max}) fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none", height=plot_height, margin={'t': 78, 'b': 71, 'r': 30}) if show: fig.show(config=_theme.get_default_config()) else: return fig
Ancestors
- dalex._explanation.Explanation
- abc.ABC
Methods
def fit(self, explainer, new_observation)
-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
explainer
:Explainer object
- Model wrapper created using the Explainer class.
new_observation
:pd.Series
ornp.ndarray
- An observation for which a prediction needs to be explained.
Returns
None
Expand source code Browse git
def fit(self, explainer, new_observation): """Calculate the result of explanation Fit method makes calculations in place and changes the attributes. Parameters ----------- explainer : Explainer object Model wrapper created using the Explainer class. new_observation : pd.Series or np.ndarray An observation for which a prediction needs to be explained. Returns ----------- None """ _new_observation = checks.check_new_observation(new_observation, explainer) checks.check_columns_in_new_observation(_new_observation, explainer) self.result, self.prediction, self.intercept, self.yhats_distributions = utils.shap( explainer, _new_observation, self.path, self.keep_distributions, self.B, self.processes, self.random_state )
def plot(self, objects=None, baseline=None, max_vars=10, digits=3, rounding_function=<function around>, bar_width=16, min_max=None, vcolors=None, title='Shapley Values', vertical_spacing=None, show=True)
-
Plot the Shapley Values explanation
Parameters
objects
:Shap object
orarray_like
ofShap objects
- Additional objects to plot in subplots (default is
None
). baseline
:float
, optional- Starting x point for bars (default is average prediction).
max_vars
:int
, optional- Maximum number of variables that will be presented for for each subplot
(default is
10
). digits
:int
, optional- Number of decimal places (
np.around
) to round contributions. Seerounding_function
parameter (default is3
). rounding_function
:function
, optional- A function that will be used for rounding numbers (default is
np.around
). bar_width
:float
, optional- Width of bars in px (default is
16
). min_max
:2-tuple
offloat
, optional- Range of OX axis (default is
[min-0.15*(max-min), max+0.15*(max-min)]
). vcolors
:3-tuple
ofstr
, optional- Color of bars (default is
["#8bdcbe", "#f05a71"]
). title
:str
, optional- Title of the plot (default is
"Shapley Values"
). vertical_spacing
:float <0, 1>
, optional- Ratio of vertical space between the plots (default is
0.2/number of rows
). show
:bool
, optionalTrue
shows the plot;False
returns the plotly Figure object that can be edited or saved using thewrite_image()
method (default isTrue
).
Returns
None
orplotly.graph_objects.Figure
- Return figure that can be edited or saved. See
show
parameter.
Expand source code Browse git
def plot(self, objects=None, baseline=None, max_vars=10, digits=3, rounding_function=np.around, bar_width=16, min_max=None, vcolors=None, title="Shapley Values", vertical_spacing=None, show=True): """Plot the Shapley Values explanation Parameters ----------- objects : Shap object or array_like of Shap objects Additional objects to plot in subplots (default is `None`). baseline: float, optional Starting x point for bars (default is average prediction). max_vars : int, optional Maximum number of variables that will be presented for for each subplot (default is `10`). digits : int, optional Number of decimal places (`np.around`) to round contributions. See `rounding_function` parameter (default is `3`). rounding_function : function, optional A function that will be used for rounding numbers (default is `np.around`). bar_width : float, optional Width of bars in px (default is `16`). min_max : 2-tuple of float, optional Range of OX axis (default is `[min-0.15*(max-min), max+0.15*(max-min)]`). vcolors : 3-tuple of str, optional Color of bars (default is `["#8bdcbe", "#f05a71"]`). title : str, optional Title of the plot (default is `"Shapley Values"`). vertical_spacing : float <0, 1>, optional Ratio of vertical space between the plots (default is `0.2/number of rows`). show : bool, optional `True` shows the plot; `False` returns the plotly Figure object that can be edited or saved using the `write_image()` method (default is `True`). Returns ----------- None or plotly.graph_objects.Figure Return figure that can be edited or saved. See `show` parameter. """ # are there any other objects to plot? if objects is None: n = 1 _result_list = [self.result.loc[self.result['B'] == 0,].copy()] _intercept_list = [self.intercept] _prediction_list = [self.prediction] elif isinstance(objects, self.__class__): # allow for objects to be a single element n = 2 _result_list = [self.result.loc[self.result['B'] == 0,].copy(), objects.result.loc[objects.result['B'] == 0,].copy()] _intercept_list = [self.intercept, objects.intercept] _prediction_list = [self.prediction, objects.prediction] elif isinstance(objects, (list, tuple)): # objects as tuple or array n = len(objects) + 1 _result_list = [self.result.loc[self.result['B'] == 0,].copy()] _intercept_list = [self.intercept] _prediction_list = [self.prediction] for ob in objects: _global_checks.global_check_object_class(ob, self.__class__) _result_list += [ob.result.loc[ob.result['B'] == 0,].copy()] _intercept_list += [ob.intercept] _prediction_list += [ob.prediction] else: _global_checks.global_raise_objects_class(objects, self.__class__) # TODO: add intercept and prediction list update for multi-class # deleted_indexes = [] # for i in range(n): # result = _result_list[i] # # if len(result['label'].unique()) > 1: # n += len(result['label'].unique()) - 1 # # add new data frames to list # _result_list += [v for k, v in result.groupby('label', sort=False)] # deleted_indexes += [i] # # _result_list = [j for i, j in enumerate(_result_list) if i not in deleted_indexes] model_names = [result.iloc[0, result.columns.get_loc("label")] for result in _result_list] if vertical_spacing is None: vertical_spacing = 0.2 / n fig = make_subplots(rows=n, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing, x_title='contribution', subplot_titles=model_names) plot_height = 78 + 71 if vcolors is None: vcolors = _theme.get_break_down_colors() if min_max is None: temp_min_max = [np.Inf, -np.Inf] else: temp_min_max = min_max for i, _result in enumerate(_result_list): if _result.shape[0] <= max_vars: m = _result.shape[0] else: m = max_vars + 1 if baseline is None: baseline = _intercept_list[i] prediction = _prediction_list[i] df = plot.prepare_data_for_shap_plot(_result, baseline, prediction, max_vars, rounding_function, digits) fig.add_shape( type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper", xref="x", line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'}, row=i + 1, col=1 ) fig.add_bar( orientation="h", y=df['variable'].tolist(), x=df['contribution'].tolist(), textposition="outside", text=df['label_text'].tolist(), marker_color=[vcolors[int(c)] for c in df['sign'].tolist()], base=baseline, hovertext=df['tooltip_text'].tolist(), hoverinfo='text', hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'}, showlegend=False, row=i + 1, col=1 ) fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2, 'automargin': True, 'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True}, row=i + 1, col=1) fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True, 'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}, row=i + 1, col=1) plot_height += m * bar_width + (m + 1) * bar_width / 4 if min_max is None: cum = df.contribution.values + baseline min_max_margin = cum.ptp() * 0.15 temp_min_max[0] = np.min([temp_min_max[0], cum.min() - min_max_margin]) temp_min_max[1] = np.max([temp_min_max[1], cum.max() + min_max_margin]) plot_height += (n - 1) * 70 fig.update_xaxes({'range': temp_min_max}) fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none", height=plot_height, margin={'t': 78, 'b': 71, 'r': 30}) if show: fig.show(config=_theme.get_default_config()) else: return fig