Module dalex.model_explanations
Expand source code Browse git
from ._aggregated_profiles.object import AggregatedProfiles
from ._model_performance.object import ModelPerformance
from ._variable_importance.object import VariableImportance
from ._residual_diagnostics import ResidualDiagnostics
__all__ = [
"ModelPerformance",
"VariableImportance",
"AggregatedProfiles",
"ResidualDiagnostics"
]
Classes
class AggregatedProfiles (type='partial', variables=None, variable_type='numerical', groups=None, span=0.25, center=True, random_state=None)
-
Calculate model-level variable profiles as Partial or Accumulated Dependence
- Partial Dependence Profile (average across Ceteris Paribus Profiles),
- Individual Conditional Expectation (local weighted average across CP Profiles),
- Accumulated Local Effects (cumulated average local changes in CP Profiles).
Parameters
type
:{'partial', 'accumulated', 'conditional'}
- Type of model profiles
(default is
'partial'
for Partial Dependence Profiles). variables
:str
orarray_like
ofstr
, optional- Variables for which the profiles will be calculated
(default is
None
, which means all of the variables). variable_type
:{'numerical', 'categorical'}
- Calculate the profiles for numerical or categorical variables
(default is
'numerical'
). groups
:str
orarray_like
ofstr
, optional- Names of categorical variables that will be used for profile grouping
(default is
None
, which means no grouping). span
:float
, optional- Smoothing coefficient used as sd for gaussian kernel (default is
0.25
). center
:bool
, optional- Theoretically Accumulated Profiles start at
0
, but are centered to compare them with Partial Dependence Profiles (default isTrue
, which means center around the average y_hat calculated on the data sample). random_state
:int
, optional- Set seed for random number generator (default is random seed).
Attributes
result
:pd.DataFrame
- Main result attribute of an explanation.
mean_prediction
:float
- Average prediction for sampled
data
(usingN
). raw_profiles
:pd.DataFrame
orNone
- Saved CeterisParibus object.
NOTE:
None
if more objects were passed to thefit
method. type
:{'partial', 'accumulated', 'conditional'}
- Type of model profiles.
variables
:array_like
ofstr
orNone
- Variables for which the profiles will be calculated
variable_type
:{'numerical', 'categorical'}
- Calculate the profiles for numerical or categorical variables.
groups
:str
orarray_like
ofstr
orNone
- Names of categorical variables that will be used for profile grouping.
span
:float
- Smoothing coefficient used as sd for gaussian kernel.
center
:bool
- Theoretically Accumulated Profiles start at
0
, but are centered to compare them with Partial Dependence Profiles (default isTrue
, which means center around the average y_hat calculated on the data sample). random_state
:int
orNone
- Set seed for random number generator.
Notes
Expand source code Browse git
class AggregatedProfiles(Explanation):
    """Calculate model-level variable profiles as Partial or Accumulated Dependence

    - Partial Dependence Profile (average across Ceteris Paribus Profiles),
    - Individual Conditional Expectation (local weighted average across CP Profiles),
    - Accumulated Local Effects (cumulated average local changes in CP Profiles).

    Parameters
    -----------
    type : {'partial', 'accumulated', 'conditional'}
        Type of model profiles
        (default is `'partial'` for Partial Dependence Profiles).
    variables : str or array_like of str, optional
        Variables for which the profiles will be calculated
        (default is `None`, which means all of the variables).
    variable_type : {'numerical', 'categorical'}
        Calculate the profiles for numerical or categorical variables
        (default is `'numerical'`).
    groups : str or array_like of str, optional
        Names of categorical variables that will be used for profile grouping
        (default is `None`, which means no grouping).
    span : float, optional
        Smoothing coefficient used as sd for gaussian kernel (default is `0.25`).
    center : bool, optional
        Theoretically Accumulated Profiles start at `0`, but are centered to compare
        them with Partial Dependence Profiles (default is `True`, which means center
        around the average y_hat calculated on the data sample).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    mean_prediction : float
        Average prediction for sampled `data` (using `N`).
    raw_profiles : pd.DataFrame or None
        Saved CeterisParibus object.
        NOTE: `None` if more objects were passed to the `fit` method.
    type : {'partial', 'accumulated', 'conditional'}
        Type of model profiles.
    variables : array_like of str or None
        Variables for which the profiles will be calculated
    variable_type : {'numerical', 'categorical'}
        Calculate the profiles for numerical or categorical variables.
    groups : str or array_like of str or None
        Names of categorical variables that will be used for profile grouping.
    span : float
        Smoothing coefficient used as sd for gaussian kernel.
    center : bool
        Theoretically Accumulated Profiles start at `0`, but are centered to compare
        them with Partial Dependence Profiles (default is `True`, which means center
        around the average y_hat calculated on the data sample).
    random_state : int or None
        Set seed for random number generator.

    Notes
    --------
    - https://pbiecek.github.io/ema/partialDependenceProfiles.html
    - https://pbiecek.github.io/ema/accumulatedLocalProfiles.html
    """

    def __init__(self,
                 type='partial',
                 variables=None,
                 variable_type='numerical',
                 groups=None,
                 span=0.25,
                 center=True,
                 random_state=None):
        # validate/normalize user input before storing it on the object
        checks.check_variable_type(variable_type)
        _variables = checks.check_variables(variables)
        _groups = checks.check_groups(groups)

        self.variable_type = variable_type
        self.groups = _groups
        self.type = type
        self.variables = _variables
        self.span = span
        self.center = center
        self.result = pd.DataFrame()
        self.mean_prediction = None
        self.raw_profiles = None
        self.random_state = random_state

    def _repr_html_(self):
        # delegate notebook rendering to the result DataFrame
        return self.result._repr_html_()

    def fit(self, ceteris_paribus, verbose=True):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        ceteris_paribus : CeterisParibus object or array_like of CeterisParibus objects
            Profile objects to aggregate.
        verbose : bool, optional
            Print tqdm progress bar (default is `True`).

        Returns
        -----------
        None
        """
        # local import avoids a circular dependency with dalex.predict_explanations
        from dalex.predict_explanations import CeterisParibus
        if isinstance(ceteris_paribus, CeterisParibus):
            # a single CeterisParibus object: keep a copy for geom='profiles' plotting
            all_profiles = ceteris_paribus.result.copy()
            all_observations = ceteris_paribus.new_observation.copy()
            self.raw_profiles = deepcopy(ceteris_paribus)
        elif isinstance(ceteris_paribus, (list, tuple)):
            # pd.concat drops the initial None silently, so the first iteration
            # simply seeds the accumulated frames
            all_profiles = None
            all_observations = None
            for cp in ceteris_paribus:
                _global_checks.global_check_object_class(cp, CeterisParibus)
                all_profiles = pd.concat([all_profiles, cp.result.copy()])
                all_observations = pd.concat([all_observations, cp.new_observation.copy()])
        else:
            _global_checks.global_raise_objects_class(ceteris_paribus, CeterisParibus)

        all_profiles, vnames = utils.prepare_numerical_categorical(
            all_profiles, self.variables, self.variable_type)

        # select only suitable variables
        all_profiles = all_profiles.loc[all_profiles['_vname_'].isin(vnames), :]

        all_profiles = utils.prepare_x(all_profiles, self.variable_type)

        self.mean_prediction = all_observations['_yhat_'].mean()

        self.result = utils.aggregate_profiles(
            all_profiles, self.mean_prediction, self.type,
            self.groups, self.center, self.span, verbose)

    def plot(self,
             objects=None,
             geom='aggregates',
             variables=None,
             center=True,
             size=2,
             alpha=1,
             color='_label_',
             facet_ncol=2,
             title="Aggregated Profiles",
             y_title='prediction',
             horizontal_spacing=0.05,
             vertical_spacing=None,
             show=True):
        """Plot the Aggregated Profiles explanation

        Parameters
        -----------
        objects : AggregatedProfiles object or array_like of AggregatedProfiles objects
            Additional objects to plot in subplots (default is `None`).
        geom : {'aggregates', 'profiles', 'bars'}
            If `'profiles'` then raw profiles will be plotted in the background,
            `'bars'` overrides the `_x_` column type and uses barplots for categorical data
            (default is `'aggregates'`, which means plot only aggregated profiles).
            NOTE: It is useful to use small values of the `N` parameter in object creation
            before using `'profiles'`, because of plot performance and clarity (e.g. `100`).
        variables : str or array_like of str, optional
            Variables for which the profiles will be calculated
            (default is `None`, which means all of the variables).
        center : bool, optional
            Theoretically Accumulated Profiles start at `0`, but are centered to compare
            them with Partial Dependence Profiles (default is `True`, which means center
            around the average y_hat calculated on the data sample).
        size : float, optional
            Width of lines in px (default is `2`).
        alpha : float <0, 1>, optional
            Opacity of lines (default is `1`).
        color : str, optional
            Variable name used for grouping
            (default is `'_label_'`, which groups by models).
        facet_ncol : int, optional
            Number of columns on the plot grid (default is `2`).
        title : str, optional
            Title of the plot (default is `"Aggregated Profiles"`).
        y_title : str, optional
            Title of the y axis (default is `"prediction"`).
        horizontal_spacing : float <0, 1>, optional
            Ratio of horizontal space between the plots (default is `0.05`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots (default is `0.3/number of rows`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can be
            edited or saved using the `write_image()` method (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """
        if geom not in ("aggregates", "profiles", "bars"):
            raise TypeError("geom should be one of {'aggregates', 'profiles', 'bars'}")
        if isinstance(variables, str):
            variables = (variables,)

        # are there any other objects to plot?
        # '_mp_' carries each object's mean prediction (0 when not centering)
        if objects is None:
            _result_df = self.result.assign(_mp_=self.mean_prediction if center else 0)
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _result_df = pd.concat(
                [self.result.assign(_mp_=self.mean_prediction if center else 0),
                 objects.result.assign(_mp_=objects.mean_prediction if center else 0)])
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            _result_df = self.result.assign(_mp_=self.mean_prediction if center else 0)
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _result_df = pd.concat(
                    [_result_df, ob.result.assign(_mp_=ob.mean_prediction if center else 0)])
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        # variables to use
        all_variables = _result_df['_vname_'].dropna().unique().tolist()
        if variables is not None:
            all_variables = _global_utils.intersect_unsorted(variables, all_variables)
            if len(all_variables) == 0:
                raise TypeError("variables do not overlap with " + ''.join(variables))
            _result_df = _result_df.loc[_result_df['_vname_'].isin(all_variables), :]

        # calculate y axis range to allow for fixedrange True
        dl = _result_df['_yhat_'].to_numpy()
        # np.ptp(dl): ndarray.ptp() was removed in NumPy 2.0
        min_max_margin = np.ptp(dl) * 0.10
        min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

        is_x_numeric = False if geom == 'bars' else pd.api.types.is_numeric_dtype(_result_df['_x_'])
        n = len(all_variables)

        facet_nrow = int(np.ceil(n / facet_ncol))
        if vertical_spacing is None:
            vertical_spacing = 0.3 / facet_nrow
        plot_height = 78 + 71 + facet_nrow * (280 + 60)
        hovermode, render_mode = 'x unified', 'svg'

        # color = '_groups_' doesnt make much sense for multiple AP objects
        m = len(_result_df[color].dropna().unique())

        if is_x_numeric:
            if geom == 'profiles' and self.raw_profiles is not None:
                # many background traces render faster with webgl
                render_mode = 'webgl'

            fig = px.line(_result_df,
                          x="_x_", y="_yhat_", color=color, facet_col="_vname_",
                          category_orders={"_vname_": list(all_variables)},
                          labels={'_yhat_': 'prediction', '_label_': 'label',
                                  '_mp_': 'mean_prediction'},
                          hover_name=color,
                          hover_data={'_yhat_': ':.3f', '_mp_': ':.3f',
                                      color: False, '_vname_': False, '_x_': False},
                          facet_col_wrap=facet_ncol,
                          facet_row_spacing=vertical_spacing,
                          facet_col_spacing=horizontal_spacing,
                          template="none",
                          render_mode=render_mode,
                          color_discrete_sequence=_theme.get_default_colors(m, 'line')) \
                    .update_traces(dict(line_width=size, opacity=alpha)) \
                    .update_xaxes({'matches': None, 'showticklabels': True,
                                   'type': 'linear', 'gridwidth': 2, 'zeroline': False,
                                   'automargin': True, 'ticks': "outside",
                                   'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}) \
                    .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False,
                                   'automargin': True, 'ticks': 'outside',
                                   'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                                   'range': min_max})

            if geom == 'profiles' and self.raw_profiles is not None:
                # emphasize the aggregates and overlay them on the raw CP profiles
                fig.update_traces(dict(line_width=2*size, opacity=1))
                fig_cp = self.raw_profiles.plot(variables=list(all_variables),
                                                facet_ncol=facet_ncol,
                                                show_observations=False, show=False) \
                    .update_traces(dict(line_width=1, opacity=0.5, line_color='#ceced9'))

                for value in fig.data:
                    fig_cp.add_trace(value)

                hovermode = False
                fig = fig_cp
        else:
            _result_df = _result_df.assign(_diff_=lambda x: x['_yhat_'] - x['_mp_'])
            mp_format = ':.3f'
            if not center:
                # bars grow from 0, so make sure 0 is inside the y range
                min_max = [np.min([min_max[0], 0]), np.max([min_max[1], 0])]
                mp_format = False

            fig = px.bar(_result_df,
                         x="_x_", y="_diff_", color="_label_", facet_col="_vname_",
                         category_orders={"_vname_": list(all_variables)},
                         labels={'_yhat_': 'prediction', '_label_': 'label',
                                 '_mp_': 'mean_prediction'},
                         hover_name=color,
                         base="_mp_",
                         hover_data={'_yhat_': ':.3f', '_mp_': mp_format, '_diff_': False,
                                     color: False, '_vname_': False, '_x_': False},
                         facet_col_wrap=facet_ncol,
                         facet_row_spacing=vertical_spacing,
                         facet_col_spacing=horizontal_spacing,
                         template="none",
                         color_discrete_sequence=_theme.get_default_colors(m, 'line'),
                         barmode='group') \
                    .update_xaxes({'matches': None, 'showticklabels': True,
                                   'type': 'category', 'gridwidth': 2, 'automargin': True,
                                   'ticks': "outside", 'tickcolor': 'white',
                                   'ticklen': 10, 'fixedrange': True}) \
                    .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False,
                                   'automargin': True, 'ticks': 'outside',
                                   'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                                   'range': min_max})

            # dotted line marking each bar group's base (the mean prediction)
            for bar in fig.data:
                fig.add_hline(y=bar.base[0], layer='below',
                              line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'})

        fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, hovermode)

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig
Ancestors
- dalex._explanation.Explanation
- abc.ABC
Methods
def fit(self, ceteris_paribus, verbose=True)
-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
ceteris_paribus
:CeterisParibus object
orarray_like
ofCeterisParibus objects
- Profile objects to aggregate.
verbose
:bool
, optional- Print tqdm progress bar (default is
True
).
Returns
None
Expand source code Browse git
def fit(self, ceteris_paribus, verbose=True): """Calculate the result of explanation Fit method makes calculations in place and changes the attributes. Parameters ----------- ceteris_paribus : CeterisParibus object or array_like of CeterisParibus objects Profile objects to aggregate. verbose : bool, optional Print tqdm progress bar (default is `True`). Returns ----------- None """ # are there any other cp? from dalex.predict_explanations import CeterisParibus if isinstance(ceteris_paribus, CeterisParibus): # allow for ceteris_paribus to be a single element all_profiles = ceteris_paribus.result.copy() all_observations = ceteris_paribus.new_observation.copy() self.raw_profiles = deepcopy(ceteris_paribus) elif isinstance(ceteris_paribus, (list, tuple)): # ceteris_paribus as tuple or array all_profiles = None all_observations = None for cp in ceteris_paribus: _global_checks.global_check_object_class(cp, CeterisParibus) all_profiles = pd.concat([all_profiles, cp.result.copy()]) all_observations = pd.concat([all_observations, cp.new_observation.copy()]) else: _global_checks.global_raise_objects_class(ceteris_paribus, CeterisParibus) all_profiles, vnames = utils.prepare_numerical_categorical(all_profiles, self.variables, self.variable_type) # select only suitable variables all_profiles = all_profiles.loc[all_profiles['_vname_'].isin(vnames), :] all_profiles = utils.prepare_x(all_profiles, self.variable_type) self.mean_prediction = all_observations['_yhat_'].mean() self.result = utils.aggregate_profiles(all_profiles, self.mean_prediction, self.type, self.groups, self.center, self.span, verbose)
def plot(self, objects=None, geom='aggregates', variables=None, center=True, size=2, alpha=1, color='_label_', facet_ncol=2, title='Aggregated Profiles', y_title='prediction', horizontal_spacing=0.05, vertical_spacing=None, show=True)
-
Plot the Aggregated Profiles explanation
Parameters
objects
:AggregatedProfiles object
orarray_like
ofAggregatedProfiles objects
- Additional objects to plot in subplots (default is
None
). geom
:{'aggregates', 'profiles', 'bars'}
- If
'profiles'
then raw profiles will be plotted in the background, 'bars' overrides the_x_
column type and uses barplots for categorical data (default is'aggregates'
, which means plot only aggregated profiles). NOTE: It is useful to use small values of theN
parameter in object creation before using'profiles'
, because of plot performance and clarity (e.g.100
). variables
:str
orarray_like
ofstr
, optional- Variables for which the profiles will be calculated
(default is
None
, which means all of the variables). center
:bool
, optional- Theoretically Accumulated Profiles start at
0
, but are centered to compare them with Partial Dependence Profiles (default isTrue
, which means center around the average y_hat calculated on the data sample). size
:float
, optional- Width of lines in px (default is
2
). alpha
:float <0, 1>
, optional- Opacity of lines (default is
1
). color
:str
, optional- Variable name used for grouping
(default is
'_label_'
, which groups by models). facet_ncol
:int
, optional- Number of columns on the plot grid (default is
2
). title
:str
, optional- Title of the plot (default is
"Aggregated Profiles"
). y_title
:str
, optional- Title of the y axis (default is
"prediction"
). horizontal_spacing
:float <0, 1>
, optional- Ratio of horizontal space between the plots (default is
0.05
). vertical_spacing
:float <0, 1>
, optional- Ratio of vertical space between the plots (default is
0.3/number of rows
). show
:bool
, optionalTrue
shows the plot;False
returns the plotly Figure object that can be edited or saved using thewrite_image()
method (default isTrue
).
Returns
None
orplotly.graph_objects.Figure
- Return figure that can be edited or saved. See
show
parameter.
Expand source code Browse git
def plot(self, objects=None, geom='aggregates', variables=None, center=True, size=2,
         alpha=1, color='_label_', facet_ncol=2, title="Aggregated Profiles",
         y_title='prediction', horizontal_spacing=0.05, vertical_spacing=None, show=True):
    """Plot the Aggregated Profiles explanation

    Parameters
    -----------
    objects : AggregatedProfiles object or array_like of AggregatedProfiles objects
        Additional objects to plot in subplots (default is `None`).
    geom : {'aggregates', 'profiles', 'bars'}
        If `'profiles'` then raw profiles will be plotted in the background,
        `'bars'` overrides the `_x_` column type and uses barplots for categorical data
        (default is `'aggregates'`, which means plot only aggregated profiles).
        NOTE: It is useful to use small values of the `N` parameter in object creation
        before using `'profiles'`, because of plot performance and clarity (e.g. `100`).
    variables : str or array_like of str, optional
        Variables for which the profiles will be calculated
        (default is `None`, which means all of the variables).
    center : bool, optional
        Theoretically Accumulated Profiles start at `0`, but are centered to compare
        them with Partial Dependence Profiles (default is `True`, which means center
        around the average y_hat calculated on the data sample).
    size : float, optional
        Width of lines in px (default is `2`).
    alpha : float <0, 1>, optional
        Opacity of lines (default is `1`).
    color : str, optional
        Variable name used for grouping
        (default is `'_label_'`, which groups by models).
    facet_ncol : int, optional
        Number of columns on the plot grid (default is `2`).
    title : str, optional
        Title of the plot (default is `"Aggregated Profiles"`).
    y_title : str, optional
        Title of the y axis (default is `"prediction"`).
    horizontal_spacing : float <0, 1>, optional
        Ratio of horizontal space between the plots (default is `0.05`).
    vertical_spacing : float <0, 1>, optional
        Ratio of vertical space between the plots (default is `0.3/number of rows`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can be
        edited or saved using the `write_image()` method (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """
    if geom not in ("aggregates", "profiles", "bars"):
        raise TypeError("geom should be one of {'aggregates', 'profiles', 'bars'}")
    if isinstance(variables, str):
        variables = (variables,)

    # are there any other objects to plot?
    # '_mp_' carries each object's mean prediction (0 when not centering)
    if objects is None:
        _result_df = self.result.assign(_mp_=self.mean_prediction if center else 0)
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _result_df = pd.concat(
            [self.result.assign(_mp_=self.mean_prediction if center else 0),
             objects.result.assign(_mp_=objects.mean_prediction if center else 0)])
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        _result_df = self.result.assign(_mp_=self.mean_prediction if center else 0)
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _result_df = pd.concat(
                [_result_df, ob.result.assign(_mp_=ob.mean_prediction if center else 0)])
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    # variables to use
    all_variables = _result_df['_vname_'].dropna().unique().tolist()
    if variables is not None:
        all_variables = _global_utils.intersect_unsorted(variables, all_variables)
        if len(all_variables) == 0:
            raise TypeError("variables do not overlap with " + ''.join(variables))
        _result_df = _result_df.loc[_result_df['_vname_'].isin(all_variables), :]

    # calculate y axis range to allow for fixedrange True
    dl = _result_df['_yhat_'].to_numpy()
    # np.ptp(dl): ndarray.ptp() was removed in NumPy 2.0
    min_max_margin = np.ptp(dl) * 0.10
    min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

    is_x_numeric = False if geom == 'bars' else pd.api.types.is_numeric_dtype(_result_df['_x_'])
    n = len(all_variables)

    facet_nrow = int(np.ceil(n / facet_ncol))
    if vertical_spacing is None:
        vertical_spacing = 0.3 / facet_nrow
    plot_height = 78 + 71 + facet_nrow * (280 + 60)
    hovermode, render_mode = 'x unified', 'svg'

    # color = '_groups_' doesnt make much sense for multiple AP objects
    m = len(_result_df[color].dropna().unique())

    if is_x_numeric:
        if geom == 'profiles' and self.raw_profiles is not None:
            # many background traces render faster with webgl
            render_mode = 'webgl'

        fig = px.line(_result_df,
                      x="_x_", y="_yhat_", color=color, facet_col="_vname_",
                      category_orders={"_vname_": list(all_variables)},
                      labels={'_yhat_': 'prediction', '_label_': 'label',
                              '_mp_': 'mean_prediction'},
                      hover_name=color,
                      hover_data={'_yhat_': ':.3f', '_mp_': ':.3f',
                                  color: False, '_vname_': False, '_x_': False},
                      facet_col_wrap=facet_ncol,
                      facet_row_spacing=vertical_spacing,
                      facet_col_spacing=horizontal_spacing,
                      template="none",
                      render_mode=render_mode,
                      color_discrete_sequence=_theme.get_default_colors(m, 'line')) \
                .update_traces(dict(line_width=size, opacity=alpha)) \
                .update_xaxes({'matches': None, 'showticklabels': True,
                               'type': 'linear', 'gridwidth': 2, 'zeroline': False,
                               'automargin': True, 'ticks': "outside",
                               'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True}) \
                .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False,
                               'automargin': True, 'ticks': 'outside',
                               'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                               'range': min_max})

        if geom == 'profiles' and self.raw_profiles is not None:
            # emphasize the aggregates and overlay them on the raw CP profiles
            fig.update_traces(dict(line_width=2*size, opacity=1))
            fig_cp = self.raw_profiles.plot(variables=list(all_variables),
                                            facet_ncol=facet_ncol,
                                            show_observations=False, show=False) \
                .update_traces(dict(line_width=1, opacity=0.5, line_color='#ceced9'))

            for value in fig.data:
                fig_cp.add_trace(value)

            hovermode = False
            fig = fig_cp
    else:
        _result_df = _result_df.assign(_diff_=lambda x: x['_yhat_'] - x['_mp_'])
        mp_format = ':.3f'
        if not center:
            # bars grow from 0, so make sure 0 is inside the y range
            min_max = [np.min([min_max[0], 0]), np.max([min_max[1], 0])]
            mp_format = False

        fig = px.bar(_result_df,
                     x="_x_", y="_diff_", color="_label_", facet_col="_vname_",
                     category_orders={"_vname_": list(all_variables)},
                     labels={'_yhat_': 'prediction', '_label_': 'label',
                             '_mp_': 'mean_prediction'},
                     hover_name=color,
                     base="_mp_",
                     hover_data={'_yhat_': ':.3f', '_mp_': mp_format, '_diff_': False,
                                 color: False, '_vname_': False, '_x_': False},
                     facet_col_wrap=facet_ncol,
                     facet_row_spacing=vertical_spacing,
                     facet_col_spacing=horizontal_spacing,
                     template="none",
                     color_discrete_sequence=_theme.get_default_colors(m, 'line'),
                     barmode='group') \
                .update_xaxes({'matches': None, 'showticklabels': True,
                               'type': 'category', 'gridwidth': 2, 'automargin': True,
                               'ticks': "outside", 'tickcolor': 'white',
                               'ticklen': 10, 'fixedrange': True}) \
                .update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False,
                               'automargin': True, 'ticks': 'outside',
                               'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True,
                               'range': min_max})

        # dotted line marking each bar group's base (the mean prediction)
        for bar in fig.data:
            fig.add_hline(y=bar.base[0], layer='below',
                          line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'})

    fig = _theme.fig_update_line_plot(fig, title, y_title, plot_height, hovermode)

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class ModelPerformance (model_type, cutoff=0.5)
-
Calculate model-level model performance measures
Parameters
model_type
:{'regression', 'classification'}
- Model task type that is used to choose the proper performance measures.
cutoff
:float
, optional- Cutoff for predictions in classification models. Needed for measures like
recall, precision, acc, f1 (default is
0.5
).
Attributes
result
:pd.DataFrame
- Main result attribute of an explanation.
residuals
:pd.DataFrame
- Residuals for
data
. model_type
:{'regression', 'classification'}
- Model task type that is used to choose the proper performance measures.
cutoff
:float
- Cutoff for predictions in classification models.
Notes
Expand source code Browse git
class ModelPerformance(Explanation):
    """Calculate model-level model performance measures

    Parameters
    -----------
    model_type : {'regression', 'classification'}
        Model task type that is used to choose the proper performance measures.
    cutoff : float, optional
        Cutoff for predictions in classification models. Needed for measures like
        recall, precision, acc, f1 (default is `0.5`).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    residuals : pd.DataFrame
        Residuals for `data`.
    model_type : {'regression', 'classification'}
        Model task type that is used to choose the proper performance measures.
    cutoff : float
        Cutoff for predictions in classification models.

    Notes
    --------
    - https://pbiecek.github.io/ema/modelPerformance.html
    """

    def __init__(self, model_type, cutoff=0.5):
        self.cutoff = cutoff
        self.model_type = model_type
        self.result = pd.DataFrame()
        self.residuals = None

    def _repr_html_(self):
        # delegate notebook rendering to the result DataFrame
        return self.result._repr_html_()

    def fit(self, explainer):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.

        Returns
        -----------
        None
        """
        # reuse cached predictions/residuals from the explainer when available
        if explainer.y_hat is not None:
            y_pred = explainer.y_hat
        else:
            y_pred = explainer.predict(explainer.data)

        if explainer.residuals is not None:
            _residuals = explainer.residuals
        else:
            _residuals = explainer.residual(explainer.data, explainer.y)

        y_true = explainer.y

        if self.model_type == 'regression':
            _mse = utils.mse(y_pred, y_true)
            _rmse = utils.rmse(y_pred, y_true)
            _r2 = utils.r2(y_pred, y_true)
            _mae = utils.mae(y_pred, y_true)
            _mad = utils.mad(y_pred, y_true)

            self.result = pd.DataFrame(
                {
                    'mse': [_mse],
                    'rmse': [_rmse],
                    'r2': [_r2],
                    'mae': [_mae],
                    'mad': [_mad]
                }, index=[explainer.label])
        elif self.model_type == 'classification':
            # confusion-matrix counts at the configured probability cutoff
            tp = ((y_true == 1) * (y_pred >= self.cutoff)).sum()
            fp = ((y_true == 0) * (y_pred >= self.cutoff)).sum()
            tn = ((y_true == 0) * (y_pred < self.cutoff)).sum()
            fn = ((y_true == 1) * (y_pred < self.cutoff)).sum()

            _recall = utils.recall(tp, fp, tn, fn)
            _precision = utils.precision(tp, fp, tn, fn)
            _f1 = utils.f1(tp, fp, tn, fn)
            _accuracy = utils.accuracy(tp, fp, tn, fn)
            _auc = utils.auc(y_pred, y_true)

            self.result = pd.DataFrame({
                'recall': [_recall],
                'precision': [_precision],
                'f1': [_f1],
                'accuracy': [_accuracy],
                'auc': [_auc]
            }, index=[explainer.label])
        else:
            raise ValueError("'model_type' must be 'regression' or 'classification'")

        _residuals = pd.DataFrame({
            'y_hat': y_pred,
            'y': y_true,
            'residuals': _residuals,
            'label': explainer.label
        })

        self.residuals = _residuals

    def plot(self, objects=None, geom="ecdf", title=None, show=False):
        """Plot the Model Performance explanation

        Parameters
        -----------
        objects : ModelPerformance object or array_like of ModelPerformance objects
            Additional objects to plot (default is `None`).
        geom : {'ecdf', 'roc', 'lift'}
            Type of plot determines how residuals shall be summarized.
        title : str, optional
            Title of the plot (default depends on the `type` attribute).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object that can be
            edited or saved using the `write_image()` method (default is `False`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """
        if geom not in ("ecdf", "roc", "lift"):
            raise TypeError("geom should be one of {'ecdf', 'roc', 'lift'}")

        # are there any other objects to plot?
        if objects is None:
            _df_list = [self.residuals.copy()]
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _df_list = [self.residuals.copy(), objects.residuals.copy()]
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            _df_list = [self.residuals.copy()]
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _df_list += [ob.residuals.copy()]
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        colors = _theme.get_default_colors(len(_df_list), 'line')

        # geom already validated above, so no trailing else is needed
        if geom == 'ecdf':
            fig = plot.plot_ecdf(_df_list, colors, title)
        elif geom == 'roc':
            fig = plot.plot_roc(_df_list, colors, title)
        else:  # geom == 'lift'
            fig = plot.plot_lift(_df_list, colors, title)

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig
Ancestors
- dalex._explanation.Explanation
- abc.ABC
Methods
def fit(self, explainer)
-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
explainer
:Explainer object
- Model wrapper created using the Explainer class.
Returns
None
Expand source code Browse git
def fit(self, explainer): """Calculate the result of explanation Fit method makes calculations in place and changes the attributes. Parameters ----------- explainer : Explainer object Model wrapper created using the Explainer class. Returns ----------- None """ if explainer.y_hat is not None: y_pred = explainer.y_hat else: y_pred = explainer.predict(explainer.data) if explainer.residuals is not None: _residuals = explainer.residuals else: _residuals = explainer.residual(explainer.data, explainer.y) y_true = explainer.y if self.model_type == 'regression': _mse = utils.mse(y_pred, y_true) _rmse = utils.rmse(y_pred, y_true) _r2 = utils.r2(y_pred, y_true) _mae = utils.mae(y_pred, y_true) _mad = utils.mad(y_pred, y_true) self.result = pd.DataFrame( { 'mse': [_mse], 'rmse': [_rmse], 'r2': [_r2], 'mae': [_mae], 'mad': [_mad] }, index=[explainer.label]) elif self.model_type == 'classification': tp = ((y_true == 1) * (y_pred >= self.cutoff)).sum() fp = ((y_true == 0) * (y_pred >= self.cutoff)).sum() tn = ((y_true == 0) * (y_pred < self.cutoff)).sum() fn = ((y_true == 1) * (y_pred < self.cutoff)).sum() _recall = utils.recall(tp, fp, tn, fn) _precision = utils.precision(tp, fp, tn, fn) _f1 = utils.f1(tp, fp, tn, fn) _accuracy = utils.accuracy(tp, fp, tn, fn) _auc = utils.auc(y_pred, y_true) self.result = pd.DataFrame({ 'recall': [_recall], 'precision': [_precision], 'f1': [_f1], 'accuracy': [_accuracy], 'auc': [_auc] }, index=[explainer.label]) else: raise ValueError("'model_type' must be 'regression' or 'classification'") _residuals = pd.DataFrame({ 'y_hat': y_pred, 'y': y_true, 'residuals': _residuals, 'label': explainer.label }) self.residuals = _residuals
def plot(self, objects=None, geom='ecdf', title=None, show=False)
-
Plot the Model Performance explanation
Parameters
objects
:ModelPerformance object
orarray_like
ofModelPerformance objects
- Additional objects to plot (default is
None
). geom
:{'ecdf', 'roc', 'lift'}
- Type of plot determines how residuals shall be summarized.
title
:str
, optional- Title of the plot (default depends on the
type
attribute). show
:bool
, optionalTrue
shows the plot;False
returns the plotly Figure object that can be edited or saved using thewrite_image()
method (default is False
).
Returns
None
orplotly.graph_objects.Figure
- Return figure that can be edited or saved. See
show
parameter.
Expand source code Browse git
def plot(self, objects=None, geom="ecdf", title=None, show=False):
    """Plot the Model Performance explanation

    Parameters
    -----------
    objects : ModelPerformance object or array_like of ModelPerformance objects
        Additional objects to plot (default is `None`).
    geom : {'ecdf', 'roc', 'lift'}
        Type of plot determines how residuals shall be summarized.
    title : str, optional
        Title of the plot (default is `None`, which means a default title
        specific to the chosen `geom`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that
        can be edited or saved using the `write_image()` method
        (default is `False`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """
    # validate geom once, up front; the dispatch below can then be exhaustive
    if geom not in ("ecdf", "roc", "lift"):
        raise TypeError("geom should be one of {'ecdf', 'roc', 'lift'}")

    # are there any other objects to plot?
    if objects is None:
        _df_list = [self.residuals.copy()]
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _df_list = [self.residuals.copy(), objects.residuals.copy()]
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        _df_list = [self.residuals.copy()]
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _df_list += [ob.residuals.copy()]
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    colors = _theme.get_default_colors(len(_df_list), 'line')

    # geom was validated above, so exactly one of these branches matches;
    # the original duplicated raise here was unreachable and has been removed
    if geom == 'ecdf':
        fig = plot.plot_ecdf(_df_list, colors, title)
    elif geom == 'roc':
        fig = plot.plot_roc(_df_list, colors, title)
    else:  # geom == 'lift'
        fig = plot.plot_lift(_df_list, colors, title)

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class ResidualDiagnostics (variables=None)
-
Calculate model-level residuals diagnostics
Parameters
variables
:str
orarray_like
ofstr
, optional- Variables for which the profiles will be calculated
(default is
None
, which means all of the variables).
Attributes
result
:pd.DataFrame
- Main result attribute of an explanation.
variables
:array_like
ofstr
orNone
- Variables for which the profiles will be calculated.
Notes
Expand source code Browse git
class ResidualDiagnostics(Explanation):
    """Calculate model-level residuals diagnostics

    Parameters
    -----------
    variables : str or array_like of str, optional
        Variables for which the profiles will be calculated
        (default is `None`, which means all of the variables).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    variables : array_like of str or None
        Variables for which the profiles will be calculated.

    Notes
    --------
    - https://pbiecek.github.io/ema/residualDiagnostic.html
    """
    def __init__(self, variables=None):
        # validate the user-supplied variable selection before storing it
        _variables = checks.check_variables(variables)

        self.result = pd.DataFrame()
        self.variables = _variables

    def _repr_html_(self):
        # notebook display delegates to the result DataFrame's HTML repr
        return self.result._repr_html_()

    def fit(self, explainer):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.

        Returns
        -----------
        None
        """
        result = explainer.data.copy()

        # if variables = NULL then all variables are added
        # otherwise only selected
        if self.variables is not None:
            result = result.loc[:, _global_utils.intersect_unsorted(self.variables, result.columns)]

        # is there target
        if explainer.y is not None:
            result = result.assign(y=explainer.y)

        # are there predictions - add y_hat to the Explainer for the future
        if explainer.y_hat is None:
            explainer.y_hat = explainer.predict(explainer.data)

        # are there residuals - add residuals to the Explainer for the future
        if explainer.residuals is None:
            explainer.residuals = explainer.residual(explainer.data, explainer.y)

        # ids are 1-based row numbers, used as hover labels in plot()
        self.result = result.assign(
            y_hat=explainer.y_hat,
            residuals=explainer.residuals,
            abs_residuals=np.abs(explainer.residuals),
            label=explainer.label,
            ids=np.arange(result.shape[0])+1
        )

    def plot(self,
             objects=None,
             variable="y_hat",
             yvariable="residuals",
             smooth=True,
             line_width=2,
             marker_size=3,
             title="Residual Diagnostics",
             N=50000,
             show=True):
        """Plot the Residual Diagnostics explanation

        Parameters
        -----------
        objects : ResidualDiagnostics object or array_like of ResidualDiagnostics objects
            Additional objects to plot (default is `None`).
        variable : str, optional
            Name of the variable from the `result` attribute to appear
            on the OX axis (default is `'y_hat'`).
        yvariable : str, optional
            Name of the variable from the `result` attribute to appear
            on the OY axis (default is `'residuals'`).
        smooth : bool, optional
            Add the smooth line (default is `True`).
        line_width : float, optional
            Width of lines in px (default is `2`).
        marker_size : float, optional
            Size of points (default is `3`).
        title : str, optional
            Title of the plot (default is `'Residual Diagnostics'`).
        N : int, optional
            Number of observations that will be sampled from the `result`
            attribute before calculating the smooth line. This is for
            performance issues with large data. `None` means take all
            `result` (default is `50_000`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object
            that can be edited or saved using the `write_image()` method
            (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """
        # lowess trendline requires statsmodels; fail early with a clear message
        _global_checks.global_check_import('statsmodels', 'smoothing line')

        # are there any other objects to plot?
        if objects is None:
            _df_list = [self.result.copy()]
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            _df_list = [self.result.copy(), objects.result.copy()]
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            _df_list = [self.result.copy()]
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _df_list += [ob.result.copy()]
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        _df = pd.concat(_df_list)

        # subsample only when the smooth line is requested (that is the slow part)
        if isinstance(N, int) and smooth:
            if N < _df.shape[0]:
                _df = _df.sample(N, random_state=0, replace=False)

        fig = px.scatter(_df,
                         x=variable,
                         y=yvariable,
                         hover_name='ids',
                         color="label",
                         trendline="lowess" if smooth else None,
                         color_discrete_sequence=_theme.get_default_colors(len(_df_list), 'line')) \
            .update_traces(dict(marker_size=marker_size, line_width=line_width))

        # wait for https://github.com/plotly/plotly.py/pull/2558 to add hline to the plot

        fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                          'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True,
                          'title_text': yvariable})

        fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                          'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True,
                          'title_text': variable})

        fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                          margin={'t': 78, 'b': 71, 'r': 30})

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig
Ancestors
- dalex._explanation.Explanation
- abc.ABC
Methods
def fit(self, explainer)
-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
explainer
:Explainer object
- Model wrapper created using the Explainer class.
Returns
None
Expand source code Browse git
def fit(self, explainer):
    """Calculate the result of explanation

    Fit method makes calculations in place and changes the attributes.

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.

    Returns
    -----------
    None
    """
    frame = explainer.data.copy()

    # keep only the requested variables; None means keep all of them
    if self.variables is not None:
        selected = _global_utils.intersect_unsorted(self.variables, frame.columns)
        frame = frame.loc[:, selected]

    # attach the target column when it is known
    if explainer.y is not None:
        frame = frame.assign(y=explainer.y)

    # lazily cache predictions and residuals on the explainer for later reuse
    if explainer.y_hat is None:
        explainer.y_hat = explainer.predict(explainer.data)
    if explainer.residuals is None:
        explainer.residuals = explainer.residual(explainer.data, explainer.y)

    # 1-based row ids serve as hover labels in plot()
    row_ids = np.arange(frame.shape[0]) + 1
    self.result = frame.assign(
        y_hat=explainer.y_hat,
        residuals=explainer.residuals,
        abs_residuals=np.abs(explainer.residuals),
        label=explainer.label,
        ids=row_ids,
    )
def plot(self, objects=None, variable='y_hat', yvariable='residuals', smooth=True, line_width=2, marker_size=3, title='Residual Diagnostics', N=50000, show=True)
-
Plot the Residual Diagnostics explanation
Parameters
objects
:ResidualDiagnostics object
orarray_like
ofResidualDiagnostics objects
- Additional objects to plot (default is
None
). variable
:str
, optional- Name of the variable from the
result
attribute to appear on the OX axis (default is'y_hat'
). yvariable
:str
, optional- Name of the variable from the
result
attribute to appear on the OY axis (default is'residuals'
). smooth
:bool
, optional- Add the smooth line (default is
True
). line_width
:float
, optional- Width of lines in px (default is
2
). marker_size
:float
, optional- Size of points (default is
3
). title
:str
, optional- Title of the plot (default depends on the
type
attribute). N
:int
, optional- Number of observations that will be sampled from the
result
attribute before calculating the smooth line. This is for performance issues with large data.None
means take allresult
(default is50_000
). show
:bool
, optionalTrue
shows the plot;False
returns the plotly Figure object that can be edited or saved using thewrite_image()
method (default isTrue
).
Returns
None
orplotly.graph_objects.Figure
- Return figure that can be edited or saved. See
show
parameter.
Expand source code Browse git
def plot(self,
         objects=None,
         variable="y_hat",
         yvariable="residuals",
         smooth=True,
         line_width=2,
         marker_size=3,
         title="Residual Diagnostics",
         N=50000,
         show=True):
    """Plot the Residual Diagnostics explanation

    Parameters
    -----------
    objects : ResidualDiagnostics object or array_like of ResidualDiagnostics objects
        Additional objects to plot (default is `None`).
    variable : str, optional
        Name of the variable from the `result` attribute to appear
        on the OX axis (default is `'y_hat'`).
    yvariable : str, optional
        Name of the variable from the `result` attribute to appear
        on the OY axis (default is `'residuals'`).
    smooth : bool, optional
        Add the smooth line (default is `True`).
    line_width : float, optional
        Width of lines in px (default is `2`).
    marker_size : float, optional
        Size of points (default is `3`).
    title : str, optional
        Title of the plot (default is `'Residual Diagnostics'`).
    N : int, optional
        Number of observations that will be sampled from the `result`
        attribute before calculating the smooth line. This is for
        performance issues with large data. `None` means take all
        `result` (default is `50_000`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object
        that can be edited or saved using the `write_image()` method
        (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """
    # lowess trendline requires statsmodels; fail early with a clear message
    _global_checks.global_check_import('statsmodels', 'smoothing line')

    # are there any other objects to plot?
    if objects is None:
        _df_list = [self.result.copy()]
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        _df_list = [self.result.copy(), objects.result.copy()]
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        _df_list = [self.result.copy()]
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _df_list += [ob.result.copy()]
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    _df = pd.concat(_df_list)

    # subsample only when the smooth line is requested (that is the slow part)
    if isinstance(N, int) and smooth:
        if N < _df.shape[0]:
            _df = _df.sample(N, random_state=0, replace=False)

    fig = px.scatter(_df,
                     x=variable,
                     y=yvariable,
                     hover_name='ids',
                     color="label",
                     trendline="lowess" if smooth else None,
                     color_discrete_sequence=_theme.get_default_colors(len(_df_list), 'line')) \
        .update_traces(dict(marker_size=marker_size, line_width=line_width))

    # wait for https://github.com/plotly/plotly.py/pull/2558 to add hline to the plot

    fig.update_yaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                      'ticks': 'outside', 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True,
                      'title_text': yvariable})

    fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                      'ticks': "outside", 'tickcolor': 'white', 'ticklen': 10, 'fixedrange': True,
                      'title_text': variable})

    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                      margin={'t': 78, 'b': 71, 'r': 30})

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig
class VariableImportance (loss_function='rmse', type='variable_importance', N=1000, B=10, variables=None, variable_groups=None, keep_raw_permutations=True, processes=1, random_state=None)
-
Calculate model-level variable importance
Parameters
loss_function
:{'rmse', '1-auc', 'mse', 'mae', 'mad'}
orfunction
, optional- If string, then such loss function will be used to assess variable importance
(default is
'rmse'
or'1-auc', depends on
model_type` attribute). type
:{'variable_importance', 'ratio', 'difference'}
, optional- Type of transformation that will be applied to dropout loss.
N
:int
, optional- Number of observations that will be sampled from the
data
attribute before the calculation of variable importance.None
means alldata
(default is1000
). B
:int
, optional- Number of permutation rounds to perform on each variable (default is
10
). variables
:array_like
ofstr
, optional- Variables for which the importance will be calculated
(default is
None
, which means all of the variables). NOTE: Ignored ifvariable_groups
is not None. variable_groups
:dict
oflists
, optional- Group the variables to calculate their joint variable importance
e.g.
{'X': ['x1', 'x2'], 'Y': ['y1', 'y2']}
(default isNone
). keep_raw_permutations
:bool
, optional- Save results for all permutation rounds (default is
True
). processes
:int
, optional- Number of parallel processes to use in calculations. Iterated over
B
(default is1
, which means no parallel computation). random_state
:int
, optional- Set seed for random number generator (default is random seed).
Attributes
result
:pd.DataFrame
- Main result attribute of an explanation.
loss_function
:function
- Loss function used to assess the variable importance.
type
:{'variable_importance', 'ratio', 'difference'}
- Type of transformation that will be applied to dropout loss.
N
:int
- Number of observations that will be sampled from the
data
attribute before the calculation of variable importance. B
:int
- Number of permutation rounds to perform on each variable.
variables
:array_like
ofstr
orNone
- Variables for which the importance will be calculated
variable_groups
:dict
oflists
orNone
- Grouped variables to calculate their joint variable importance.
keep_raw_permutations
:bool
- Save the results for all permutation rounds.
permutation
:pd.DataFrame
orNone
- The results for all permutation rounds.
processes
:int
- Number of parallel processes to use in calculations. Iterated over
B
. random_state
:int
orNone
- Set seed for random number generator.
Notes
Expand source code Browse git
class VariableImportance(Explanation):
    """Calculate model-level variable importance

    Parameters
    -----------
    loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
        If string, then such loss function will be used to assess variable importance
        (default is `'rmse'` or `'1-auc'`, depending on the `model_type` attribute).
    type : {'variable_importance', 'ratio', 'difference'}, optional
        Type of transformation that will be applied to dropout loss.
    N : int, optional
        Number of observations that will be sampled from the `data` attribute
        before the calculation of variable importance.
        `None` means all `data` (default is `1000`).
    B : int, optional
        Number of permutation rounds to perform on each variable (default is `10`).
    variables : array_like of str, optional
        Variables for which the importance will be calculated
        (default is `None`, which means all of the variables).
        NOTE: Ignored if `variable_groups` is not None.
    variable_groups : dict of lists, optional
        Group the variables to calculate their joint variable importance
        e.g. `{'X': ['x1', 'x2'], 'Y': ['y1', 'y2']}` (default is `None`).
    keep_raw_permutations : bool, optional
        Save results for all permutation rounds (default is `True`).
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).

    Attributes
    -----------
    result : pd.DataFrame
        Main result attribute of an explanation.
    loss_function : function
        Loss function used to assess the variable importance.
    type : {'variable_importance', 'ratio', 'difference'}
        Type of transformation that will be applied to dropout loss.
    N : int
        Number of observations that will be sampled from the `data` attribute
        before the calculation of variable importance.
    B : int
        Number of permutation rounds to perform on each variable.
    variables : array_like of str or None
        Variables for which the importance will be calculated
    variable_groups : dict of lists or None
        Grouped variables to calculate their joint variable importance.
    keep_raw_permutations : bool
        Save the results for all permutation rounds.
    permutation : pd.DataFrame or None
        The results for all permutation rounds.
    processes : int
        Number of parallel processes to use in calculations. Iterated over `B`.
    random_state : int or None
        Set seed for random number generator.

    Notes
    --------
    - https://pbiecek.github.io/ema/featureImportance.html
    """
    def __init__(self,
                 loss_function='rmse',
                 type='variable_importance',
                 N=1000,
                 B=10,
                 variables=None,
                 variable_groups=None,
                 keep_raw_permutations=True,
                 processes=1,
                 random_state=None):
        # validate user input before storing it on the object
        _loss_function = checks.check_loss_function(loss_function)
        _B = checks.check_B(B)
        _type = checks.check_type(type)
        _random_state = checks.check_random_state(random_state)
        _keep_raw_permutations = checks.check_keep_raw_permutations(keep_raw_permutations, B)
        _processes = checks.check_processes(processes)

        self.loss_function = _loss_function
        self.type = _type
        self.N = N
        self.B = _B
        self.variables = variables
        self.variable_groups = variable_groups
        self.random_state = _random_state
        self.keep_raw_permutations = _keep_raw_permutations
        self.result = pd.DataFrame()
        self.permutation = None
        self.processes = _processes

    def _repr_html_(self):
        # notebook display delegates to the result DataFrame's HTML repr
        return self.result._repr_html_()

    def fit(self, explainer):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.

        Returns
        -----------
        None
        """
        # if `variable_groups` are not specified, then extract from `variables`
        self.variable_groups = checks.check_variable_groups(self.variable_groups, explainer)
        self.variables = checks.check_variables(self.variables, self.variable_groups, explainer)

        self.result, self.permutation = utils.calculate_variable_importance(explainer,
                                                                            self.type,
                                                                            self.loss_function,
                                                                            self.variables,
                                                                            self.N,
                                                                            self.B,
                                                                            explainer.label,
                                                                            self.processes,
                                                                            self.keep_raw_permutations,
                                                                            self.random_state)

    def plot(self,
             objects=None,
             max_vars=10,
             digits=3,
             rounding_function=np.around,
             bar_width=16,
             split=("model", "variable"),
             title="Variable Importance",
             vertical_spacing=None,
             show=True):
        """Plot the Variable Importance explanation

        Parameters
        -----------
        objects : VariableImportance object or array_like of VariableImportance objects
            Additional objects to plot in subplots (default is `None`).
        max_vars : int, optional
            Maximum number of variables that will be presented for each subplot
            (default is `10`).
        digits : int, optional
            Number of decimal places (`np.around`) to round contributions.
            See `rounding_function` parameter (default is `3`).
        rounding_function : function, optional
            A function that will be used for rounding numbers (default is `np.around`).
        bar_width : float, optional
            Width of bars in px (default is `16`).
        split : {'model', 'variable'}, optional
            Split the subplots by model or variable (default is `'model'`).
        title : str, optional
            Title of the plot (default is `"Variable Importance"`).
        vertical_spacing : float <0, 1>, optional
            Ratio of vertical space between the plots
            (default is `0.2/number of rows`).
        show : bool, optional
            `True` shows the plot; `False` returns the plotly Figure object
            that can be edited or saved using the `write_image()` method
            (default is `True`).

        Returns
        -----------
        None or plotly.graph_objects.Figure
            Return figure that can be edited or saved. See `show` parameter.
        """
        # default for split is a tuple; only its first element is used
        if isinstance(split, tuple):
            split = split[0]

        if split not in ("model", "variable"):
            raise TypeError("split should be 'model' or 'variable'")

        # are there any other objects to plot?
        if objects is None:
            n = 1
            _result_df = self.result.copy()
            if split == 'variable':  # force split by model if only one explainer
                split = 'model'
        elif isinstance(objects, self.__class__):  # allow for objects to be a single element
            n = 2
            _result_df = pd.concat([self.result.copy(), objects.result.copy()])
        elif isinstance(objects, (list, tuple)):  # objects as tuple or array
            n = len(objects) + 1
            _result_df = self.result.copy()
            for ob in objects:
                _global_checks.global_check_object_class(ob, self.__class__)
                _result_df = pd.concat([_result_df, ob.result.copy()])
        else:
            _global_checks.global_raise_objects_class(objects, self.__class__)

        # shared x-axis range over all models, padded by 15% of the spread
        dl = _result_df.loc[_result_df.variable != '_baseline_', 'dropout_loss'].to_numpy()
        min_max_margin = dl.ptp() * 0.15
        min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

        # take out full model
        best_fits = _result_df[_result_df.variable == '_full_model_']

        # this produces dropout_loss_x and dropout_loss_y columns
        _result_df = _result_df.merge(best_fits[['label', 'dropout_loss']], how="left", on="label")
        _result_df = _result_df[['label', 'variable', 'dropout_loss_x', 'dropout_loss_y']].rename(
            columns={'dropout_loss_x': 'dropout_loss', 'dropout_loss_y': 'full_model'})

        # remove full_model and baseline
        _result_df = _result_df[(_result_df.variable != '_full_model_') &
                                (_result_df.variable != '_baseline_')]

        # calculate order of bars or variable plots (split = 'variable')
        # get variable permutation
        perm = _result_df[['variable', 'dropout_loss']].groupby('variable').mean().reset_index() \
            .sort_values('dropout_loss', ascending=False).variable.values

        # base height accounts for top/bottom margins; bars are added per subplot
        plot_height = 78 + 71

        colors = _theme.get_default_colors(n, 'bar')

        if vertical_spacing is None:
            vertical_spacing = 0.2 / n

        model_names = _result_df['label'].unique().tolist()

        if len(model_names) != n:
            raise ValueError('label must be unique for each model')

        if split == "model":
            # init plot
            fig = make_subplots(rows=n, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing,
                                x_title='drop-out loss', subplot_titles=model_names)

            # split df by model
            df_list = [v for k, v in _result_df.groupby('label', sort=False)]

            for i, df in enumerate(df_list):
                m = df.shape[0]
                if max_vars is not None and max_vars < m:
                    m = max_vars

                # take only m variables (for max_vars)
                # sort rows of df by variable permutation and drop unused variables
                df = df.sort_values('dropout_loss').tail(m) \
                    .set_index('variable').reindex(perm).dropna().reset_index()

                baseline = df.iloc[0, df.columns.get_loc('full_model')]

                df = df.assign(difference=lambda x: x['dropout_loss'] - baseline)

                # label bars with signed, rounded differences from the full model
                lt = df.difference.apply(lambda val:
                                         "+"+str(rounding_function(np.abs(val), digits)) if val > 0
                                         else str(rounding_function(np.abs(val), digits)))
                tt = df.apply(lambda row: plot.tooltip_text(row, rounding_function, digits), axis=1)
                df = df.assign(label_text=lt, tooltip_text=tt)

                # dotted vertical line marking the full-model loss
                fig.add_shape(type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper",
                              xref="x", line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'},
                              row=i + 1, col=1)

                fig.add_bar(
                    orientation="h",
                    y=df['variable'].tolist(),
                    x=df['difference'].tolist(),
                    textposition="outside",
                    text=df['label_text'].tolist(),
                    marker_color=colors[i],
                    base=baseline,
                    hovertext=df['tooltip_text'].tolist(),
                    hoverinfo='text',
                    hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                    showlegend=False,
                    row=i + 1, col=1
                )

                fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2,
                                  'automargin': True, 'ticks': 'outside', 'tickcolor': 'white',
                                  'ticklen': 10, 'fixedrange': True}, row=i + 1, col=1)

                fig.update_xaxes(
                    {'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                     'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3, 'fixedrange': True},
                    row=i + 1, col=1)

                plot_height += m * bar_width + (m + 1) * bar_width / 4 + 30
        else:
            # split df by variable
            df_list = [v for k, v in _result_df.groupby('variable', sort=False)]

            n = len(df_list)
            if max_vars is not None and max_vars < n:
                n = max_vars
            # NOTE(review): vertical_spacing was already defaulted above when None,
            # so this inner check never fires — confirm intended behavior
            if vertical_spacing is None:
                vertical_spacing = 0.2 / n

            # init plot
            variable_names = perm[0:n]
            fig = make_subplots(rows=n, cols=1, shared_xaxes=True, vertical_spacing=vertical_spacing,
                                x_title='drop-out loss', subplot_titles=variable_names)

            df_dict = {e.variable.array[0]: e for e in df_list}

            # take only n=max_vars elements from df_dict
            for i in range(n):
                df = df_dict[perm[i]]
                m = df.shape[0]

                baseline = 0

                df = df.assign(difference=lambda x: x['dropout_loss'] - x['full_model'])

                lt = df.difference.apply(lambda val:
                                         "+"+str(rounding_function(np.abs(val), digits)) if val > 0
                                         else str(rounding_function(np.abs(val), digits)))
                tt = df.apply(lambda row: plot.tooltip_text(row, rounding_function, digits), axis=1)
                df = df.assign(label_text=lt, tooltip_text=tt)

                fig.add_shape(type='line', x0=baseline, x1=baseline, y0=-1, y1=m, yref="paper",
                              xref="x", line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'},
                              row=i + 1, col=1)

                fig.add_bar(
                    orientation="h",
                    y=df['label'].tolist(),
                    x=df['dropout_loss'].tolist(),
                    # textposition="outside",
                    # text=df['label_text'].tolist(),
                    marker_color=colors,
                    base=baseline,
                    hovertext=df['tooltip_text'].tolist(),
                    hoverinfo='text',
                    hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                    showlegend=False,
                    row=i + 1, col=1)

                fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2,
                                  'automargin': True, 'ticks': 'outside', 'tickcolor': 'white',
                                  'ticklen': 10, 'dtick': 1, 'fixedrange': True}, row=i + 1, col=1)

                fig.update_xaxes(
                    {'type': 'linear', 'gridwidth': 2, 'zeroline': False, 'automargin': True,
                     'ticks': "outside", 'tickcolor': 'white', 'ticklen': 3,
                     'fixedrange': True},
                    row=i + 1, col=1)

                plot_height += m * bar_width + (m + 1) * bar_width / 4

            plot_height += (n - 1) * 70

        fig.update_xaxes({'range': min_max})

        fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"}, template="none",
                          height=plot_height, margin={'t': 78, 'b': 71, 'r': 30})

        if show:
            fig.show(config=_theme.get_default_config())
        else:
            return fig
Ancestors
- dalex._explanation.Explanation
- abc.ABC
Subclasses
- dalex.aspect._model_aspect_importance.object.ModelAspectImportance
Methods
def fit(self, explainer)
-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
explainer
:Explainer object
- Model wrapper created using the Explainer class.
Returns
None
Expand source code Browse git
def fit(self, explainer):
    """Calculate the result of explanation

    Fit method makes calculations in place and changes the attributes.

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.

    Returns
    -----------
    None
    """
    # if `variable_groups` are not specified, then extract from `variables`
    groups = checks.check_variable_groups(self.variable_groups, explainer)
    variables = checks.check_variables(self.variables, groups, explainer)
    self.variable_groups = groups
    self.variables = variables

    importance = utils.calculate_variable_importance(
        explainer,
        self.type,
        self.loss_function,
        self.variables,
        self.N,
        self.B,
        explainer.label,
        self.processes,
        self.keep_raw_permutations,
        self.random_state,
    )
    self.result, self.permutation = importance
def plot(self, objects=None, max_vars=10, digits=3, rounding_function=<function around>, bar_width=16, split=('model', 'variable'), title='Variable Importance', vertical_spacing=None, show=True)
-
Plot the Variable Importance explanation
Parameters
objects
:VariableImportance object
orarray_like
ofVariableImportance objects
- Additional objects to plot in subplots (default is
None
). max_vars
:int
, optional- Maximum number of variables that will be presented for for each subplot
(default is
10
). digits
:int
, optional- Number of decimal places (
np.around
) to round contributions. Seerounding_function
parameter (default is3
). rounding_function
:function
, optional- A function that will be used for rounding numbers (default is
np.around
). bar_width
:float
, optional- Width of bars in px (default is
16
). split
:{'model', 'variable'}
, optional- Split the subplots by model or variable (default is
'model'
). title
:str
, optional- Title of the plot (default is
"Variable Importance"
). vertical_spacing
:float <0, 1>
, optional- Ratio of vertical space between the plots (default is
0.2/number of rows
). show
:bool
, optionalTrue
shows the plot;False
returns the plotly Figure object that can be edited or saved using thewrite_image()
method (default isTrue
).
Returns
None
orplotly.graph_objects.Figure
- Return figure that can be edited or saved. See
show
parameter.
Expand source code Browse git
def plot(self, objects=None, max_vars=10, digits=3, rounding_function=np.around,
         bar_width=16, split=("model", "variable"), title="Variable Importance",
         vertical_spacing=None, show=True):
    """Plot the Variable Importance explanation

    Parameters
    -----------
    objects : VariableImportance object or array_like of VariableImportance objects
        Additional objects to plot in subplots (default is `None`).
    max_vars : int, optional
        Maximum number of variables that will be presented for each subplot
        (default is `10`).
    digits : int, optional
        Number of decimal places (`np.around`) to round contributions.
        See `rounding_function` parameter (default is `3`).
    rounding_function : function, optional
        A function that will be used for rounding numbers (default is `np.around`).
    bar_width : float, optional
        Width of bars in px (default is `16`).
    split : {'model', 'variable'}, optional
        Split the subplots by model or variable (default is `'model'`).
    title : str, optional
        Title of the plot (default is `"Variable Importance"`).
    vertical_spacing : float <0, 1>, optional
        Ratio of vertical space between the plots
        (default is `0.2/number of rows`).
    show : bool, optional
        `True` shows the plot; `False` returns the plotly Figure object that can
        be edited or saved using the `write_image()` method (default is `True`).

    Returns
    -----------
    None or plotly.graph_objects.Figure
        Return figure that can be edited or saved. See `show` parameter.
    """
    # the default for `split` is a tuple of the allowed values; take the first
    if isinstance(split, tuple):
        split = split[0]

    if split not in ("model", "variable"):
        raise TypeError("split should be 'model' or 'variable'")

    # are there any other objects to plot?
    if objects is None:
        n = 1
        _result_df = self.result.copy()
        if split == 'variable':  # force split by model if only one explainer
            split = 'model'
    elif isinstance(objects, self.__class__):  # allow for objects to be a single element
        n = 2
        _result_df = pd.concat([self.result.copy(), objects.result.copy()])
    elif isinstance(objects, (list, tuple)):  # objects as tuple or array
        n = len(objects) + 1
        _result_df = self.result.copy()
        for ob in objects:
            _global_checks.global_check_object_class(ob, self.__class__)
            _result_df = pd.concat([_result_df, ob.result.copy()])
    else:
        _global_checks.global_raise_objects_class(objects, self.__class__)

    # common x-axis range with a 15% margin on either side;
    # NOTE: np.ptp(dl) instead of dl.ptp() — the ndarray.ptp method
    # was removed in NumPy 2.0
    dl = _result_df.loc[_result_df.variable != '_baseline_', 'dropout_loss'].to_numpy()
    min_max_margin = np.ptp(dl) * 0.15
    min_max = [dl.min() - min_max_margin, dl.max() + min_max_margin]

    # take out full model
    best_fits = _result_df[_result_df.variable == '_full_model_']

    # this produces dropout_loss_x and dropout_loss_y columns
    _result_df = _result_df.merge(best_fits[['label', 'dropout_loss']], how="left", on="label")
    _result_df = _result_df[['label', 'variable', 'dropout_loss_x', 'dropout_loss_y']].rename(
        columns={'dropout_loss_x': 'dropout_loss', 'dropout_loss_y': 'full_model'})

    # remove full_model and baseline
    _result_df = _result_df[
        (_result_df.variable != '_full_model_') & (_result_df.variable != '_baseline_')]

    # calculate order of bars or variable plots (split = 'variable')
    # get variable permutation (descending by mean dropout loss across models)
    perm = _result_df[['variable', 'dropout_loss']].groupby('variable').mean().reset_index(). \
        sort_values('dropout_loss', ascending=False).variable.values

    plot_height = 78 + 71
    colors = _theme.get_default_colors(n, 'bar')

    # remember whether the caller supplied vertical_spacing: the documented
    # default is 0.2/number of rows, and for split='variable' the number of
    # rows is recomputed below (previously the re-check there was dead code
    # because vertical_spacing was no longer None)
    auto_vertical_spacing = vertical_spacing is None
    if auto_vertical_spacing:
        vertical_spacing = 0.2 / n

    model_names = _result_df['label'].unique().tolist()

    if len(model_names) != n:
        raise ValueError('label must be unique for each model')

    if split == "model":
        # init plot: one subplot (row) per model
        fig = make_subplots(rows=n, cols=1, shared_xaxes=True,
                            vertical_spacing=vertical_spacing,
                            x_title='drop-out loss', subplot_titles=model_names)

        # split df by model
        df_list = [v for k, v in _result_df.groupby('label', sort=False)]

        for i, df in enumerate(df_list):
            m = df.shape[0]
            if max_vars is not None and max_vars < m:
                m = max_vars

            # take only m variables (for max_vars)
            # sort rows of df by variable permutation and drop unused variables
            df = df.sort_values('dropout_loss').tail(m) \
                .set_index('variable').reindex(perm).dropna().reset_index()

            # full-model loss for this model; bars are drawn relative to it
            baseline = df.iloc[0, df.columns.get_loc('full_model')]

            df = df.assign(difference=lambda x: x['dropout_loss'] - baseline)

            lt = df.difference.apply(
                lambda val: "+" + str(rounding_function(np.abs(val), digits))
                if val > 0 else str(rounding_function(np.abs(val), digits)))
            # `plot` here is the sibling helper module, not this method
            tt = df.apply(lambda row: plot.tooltip_text(row, rounding_function, digits), axis=1)
            df = df.assign(label_text=lt, tooltip_text=tt)

            # dotted vertical line marking the full-model loss
            fig.add_shape(type='line', x0=baseline, x1=baseline, y0=-1, y1=m,
                          yref="paper", xref="x",
                          line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'},
                          row=i + 1, col=1)

            fig.add_bar(
                orientation="h",
                y=df['variable'].tolist(),
                x=df['difference'].tolist(),
                textposition="outside",
                text=df['label_text'].tolist(),
                marker_color=colors[i],
                base=baseline,
                hovertext=df['tooltip_text'].tolist(),
                hoverinfo='text',
                hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                showlegend=False,
                row=i + 1, col=1
            )

            fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2,
                              'automargin': True, 'ticks': 'outside', 'tickcolor': 'white',
                              'ticklen': 10, 'fixedrange': True},
                             row=i + 1, col=1)

            fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False,
                              'automargin': True, 'ticks': "outside", 'tickcolor': 'white',
                              'ticklen': 3, 'fixedrange': True},
                             row=i + 1, col=1)

            plot_height += m * bar_width + (m + 1) * bar_width / 4 + 30
    else:
        # split df by variable: one subplot (row) per variable
        df_list = [v for k, v in _result_df.groupby('variable', sort=False)]
        n = len(df_list)
        if max_vars is not None and max_vars < n:
            n = max_vars
        if auto_vertical_spacing:
            vertical_spacing = 0.2 / n

        # init plot
        variable_names = perm[0:n]
        fig = make_subplots(rows=n, cols=1, shared_xaxes=True,
                            vertical_spacing=vertical_spacing,
                            x_title='drop-out loss', subplot_titles=variable_names)

        df_dict = {e.variable.array[0]: e for e in df_list}

        # take only n=max_vars elements from df_dict
        for i in range(n):
            df = df_dict[perm[i]]
            m = df.shape[0]

            # bars start at 0 — each model keeps its own full-model reference
            baseline = 0

            df = df.assign(difference=lambda x: x['dropout_loss'] - x['full_model'])

            lt = df.difference.apply(
                lambda val: "+" + str(rounding_function(np.abs(val), digits))
                if val > 0 else str(rounding_function(np.abs(val), digits)))
            tt = df.apply(lambda row: plot.tooltip_text(row, rounding_function, digits), axis=1)
            df = df.assign(label_text=lt, tooltip_text=tt)

            fig.add_shape(type='line', x0=baseline, x1=baseline, y0=-1, y1=m,
                          yref="paper", xref="x",
                          line={'color': "#371ea3", 'width': 1.5, 'dash': 'dot'},
                          row=i + 1, col=1)

            fig.add_bar(
                orientation="h",
                y=df['label'].tolist(),
                x=df['dropout_loss'].tolist(),
                # textposition="outside",
                # text=df['label_text'].tolist(),
                marker_color=colors,
                base=baseline,
                hovertext=df['tooltip_text'].tolist(),
                hoverinfo='text',
                hoverlabel={'bgcolor': 'rgba(0,0,0,0.8)'},
                showlegend=False,
                row=i + 1, col=1)

            fig.update_yaxes({'type': 'category', 'autorange': 'reversed', 'gridwidth': 2,
                              'automargin': True, 'ticks': 'outside', 'tickcolor': 'white',
                              'ticklen': 10, 'dtick': 1, 'fixedrange': True},
                             row=i + 1, col=1)

            fig.update_xaxes({'type': 'linear', 'gridwidth': 2, 'zeroline': False,
                              'automargin': True, 'ticks': "outside", 'tickcolor': 'white',
                              'ticklen': 3, 'fixedrange': True},
                             row=i + 1, col=1)

            plot_height += m * bar_width + (m + 1) * bar_width / 4

        plot_height += (n - 1) * 70

    fig.update_xaxes({'range': min_max})
    fig.update_layout(title_text=title, title_x=0.15, font={'color': "#371ea3"},
                      template="none", height=plot_height,
                      margin={'t': 78, 'b': 71, 'r': 30})

    if show:
        fig.show(config=_theme.get_default_config())
    else:
        return fig