Module dalex.wrappers
Expand source code Browse git
from ._shap.object import ShapWrapper
__all__ = [
"ShapWrapper"
]
Classes
class ShapWrapper (type)
-
Explanation wrapper for the
shap
package. This object uses the shap package to create the model explanation. See https://github.com/slundberg/shap
Parameters
type
:{'predict_parts', 'model_parts'}
- Method type for calculations.
Attributes
result
:list
or numpy.ndarray
- Calculated shap values for
new_observation
data. shap_explainer
:{shap.TreeExplainer, shap.DeepExplainer, shap.GradientExplainer, shap.LinearExplainer, shap.KernelExplainer}
- Explainer object from the
shap
package. shap_explainer_type
:{'TreeExplainer', 'DeepExplainer',
- 'GradientExplainer', 'LinearExplainer', 'KernelExplainer'} String name of the Explainer class.
new_observation
:pandas.Series
or pandas.DataFrame
- Observations for which the shap values will be calculated
(later stored in
result
). type
:{'predict_parts', 'model_parts'}
- Method type for calculations.
Notes
Expand source code Browse git
class ShapWrapper(Explanation):
    """Explanation wrapper for the `shap` package

    This object uses the shap package to create the model explanation.
    See https://github.com/slundberg/shap

    Parameters
    ----------
    type : {'predict_parts', 'model_parts'}
        Method type for calculations.

    Attributes
    ----------
    result : list or numpy.ndarray
        Calculated shap values for `new_observation` data.
    shap_explainer : {shap.TreeExplainer, shap.DeepExplainer, shap.GradientExplainer,
            shap.LinearExplainer, shap.KernelExplainer}
        Explainer object from the `shap` package.
    shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
            'LinearExplainer', 'KernelExplainer'}
        String name of the Explainer class.
    new_observation : pandas.Series or pandas.DataFrame
        Observations for which the shap values will be calculated
        (later stored in `result`).
    type : {'predict_parts', 'model_parts'}
        Method type for calculations.

    Notes
    -----
    - https://github.com/slundberg/shap
    """

    def __init__(self, type):
        # NOTE: `type` shadows the builtin, but the parameter name is part of
        # the public API and is kept for backward compatibility.
        self.shap_explainer = None
        self.type = type
        self.result = None
        self.new_observation = None
        self.shap_explainer_type = None

    def fit(self, explainer, new_observation, shap_explainer_type=None, **kwargs):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        ----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        new_observation : pd.Series or np.ndarray
            An observation for which a prediction needs to be explained.
        shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
                'LinearExplainer', 'KernelExplainer'}
            String name of the Explainer class
            (default is `None`, which automatically chooses an Explainer to use).
        kwargs : dict
            Keyword parameters passed to the `shap_values` method.

        Returns
        -------
        None
        """
        # shap is imported lazily so that dalex does not require it at import time
        from shap import TreeExplainer, DeepExplainer, GradientExplainer, \
            LinearExplainer, KernelExplainer

        checks.check_compatibility(explainer)
        shap_explainer_type = checks.check_shap_explainer_type(
            shap_explainer_type, explainer.model)

        if self.type == 'predict_parts':
            new_observation = checks.check_new_observation_predict_parts(
                new_observation, explainer)

        if shap_explainer_type == "TreeExplainer":
            try:
                self.shap_explainer = TreeExplainer(
                    explainer.model, explainer.data.values)
            except Exception:
                # deliberate best-effort fallback without background data,
                # see https://github.com/ModelOriented/DALEX/issues/371
                self.shap_explainer = TreeExplainer(explainer.model)
        elif shap_explainer_type == "DeepExplainer":
            self.shap_explainer = DeepExplainer(
                explainer.model, explainer.data.values)
        elif shap_explainer_type == "GradientExplainer":
            self.shap_explainer = GradientExplainer(
                explainer.model, explainer.data.values)
        elif shap_explainer_type == "LinearExplainer":
            self.shap_explainer = LinearExplainer(
                explainer.model, explainer.data.values)
        elif shap_explainer_type == "KernelExplainer":
            # KernelExplainer is model-agnostic: wrap the dalex predict function
            self.shap_explainer = KernelExplainer(
                lambda x: explainer.predict(x),
                explainer.data.values
            )

        self.result = self.shap_explainer.shap_values(
            new_observation.values, **kwargs)
        self.new_observation = new_observation
        self.shap_explainer_type = shap_explainer_type

    def plot(self, **kwargs):
        """Plot the Shap Wrapper

        Parameters
        ----------
        kwargs : dict
            Keyword arguments passed to one of the:
                - shap.force_plot when type is `'predict_parts'`
                - shap.summary_plot when type is `'model_parts'`
            Exceptions are: `base_value`, `shap_values`, `features`
            and `feature_names`.

        Returns
        -------
        None

        Notes
        -----
        - https://github.com/slundberg/shap
        """
        from shap import force_plot, summary_plot
        if self.type == 'predict_parts':
            # multi-output explainers return a list/array of expected values;
            # index 1 follows the shap convention for the positive class in
            # binary classification — TODO confirm for multiclass models
            if isinstance(self.shap_explainer.expected_value, (np.ndarray, list)):
                base_value = self.shap_explainer.expected_value[1]
            else:
                base_value = self.shap_explainer.expected_value

            shap_values = self.result[1] if isinstance(self.result, list) else self.result

            force_plot(base_value=base_value,
                       shap_values=shap_values,
                       features=self.new_observation.values,
                       feature_names=self.new_observation.columns,
                       matplotlib=True,
                       **kwargs)
        elif self.type == 'model_parts':
            summary_plot(shap_values=self.result,
                         features=self.new_observation,
                         **kwargs)
Ancestors
- dalex._explanation.Explanation
- abc.ABC
Methods
def fit(self, explainer, new_observation, shap_explainer_type=None, **kwargs)
-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
explainer
:Explainer object
- Model wrapper created using the Explainer class.
new_observation
:pd.Series
or np.ndarray
- An observation for which a prediction needs to be explained.
shap_explainer_type
:{'TreeExplainer', 'DeepExplainer', 'GradientExplainer', 'LinearExplainer', 'KernelExplainer'}
- String name of the Explainer class (default is
None
, which automatically chooses an Explainer to use). kwargs
:dict
- Keyword parameters passed to the
shap_values
method.
Returns
None
Expand source code Browse git
def fit(self, explainer, new_observation, shap_explainer_type=None, **kwargs):
    """Calculate the result of explanation

    Fit method makes calculations in place and changes the attributes.

    Parameters
    ----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.
    new_observation : pd.Series or np.ndarray
        An observation for which a prediction needs to be explained.
    shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
            'LinearExplainer', 'KernelExplainer'}
        String name of the Explainer class
        (default is `None`, which automatically chooses an Explainer to use).
    kwargs : dict
        Keyword parameters passed to the `shap_values` method.

    Returns
    -------
    None
    """
    # shap is imported lazily so that dalex does not require it at import time
    from shap import TreeExplainer, DeepExplainer, GradientExplainer, \
        LinearExplainer, KernelExplainer

    checks.check_compatibility(explainer)
    shap_explainer_type = checks.check_shap_explainer_type(
        shap_explainer_type, explainer.model)

    if self.type == 'predict_parts':
        new_observation = checks.check_new_observation_predict_parts(
            new_observation, explainer)

    if shap_explainer_type == "TreeExplainer":
        try:
            self.shap_explainer = TreeExplainer(
                explainer.model, explainer.data.values)
        except Exception:
            # deliberate best-effort fallback without background data,
            # see https://github.com/ModelOriented/DALEX/issues/371
            self.shap_explainer = TreeExplainer(explainer.model)
    elif shap_explainer_type == "DeepExplainer":
        self.shap_explainer = DeepExplainer(
            explainer.model, explainer.data.values)
    elif shap_explainer_type == "GradientExplainer":
        self.shap_explainer = GradientExplainer(
            explainer.model, explainer.data.values)
    elif shap_explainer_type == "LinearExplainer":
        self.shap_explainer = LinearExplainer(
            explainer.model, explainer.data.values)
    elif shap_explainer_type == "KernelExplainer":
        # KernelExplainer is model-agnostic: wrap the dalex predict function
        self.shap_explainer = KernelExplainer(
            lambda x: explainer.predict(x),
            explainer.data.values
        )

    self.result = self.shap_explainer.shap_values(
        new_observation.values, **kwargs)
    self.new_observation = new_observation
    self.shap_explainer_type = shap_explainer_type
def plot(self, **kwargs)
-
Plot the Shap Wrapper
Parameters
kwargs
:dict
- Keyword arguments passed to one of the:
- shap.force_plot when type is
'predict_parts'
- shap.summary_plot when type is 'model_parts'
Exceptions are: `base_value`, `shap_values`, `features` and `feature_names`.
Returns
None
Notes
Expand source code Browse git
def plot(self, **kwargs):
    """Plot the Shap Wrapper

    Parameters
    ----------
    kwargs : dict
        Keyword arguments passed to one of the:
            - shap.force_plot when type is `'predict_parts'`
            - shap.summary_plot when type is `'model_parts'`
        Exceptions are: `base_value`, `shap_values`, `features`
        and `feature_names`.

    Returns
    -------
    None

    Notes
    -----
    - https://github.com/slundberg/shap
    """
    from shap import force_plot, summary_plot

    if self.type == 'predict_parts':
        expected = self.shap_explainer.expected_value
        # multi-output explainers yield per-class values; index 1 follows the
        # shap convention for the positive class — TODO confirm for multiclass
        if isinstance(expected, (np.ndarray, list)):
            base_value = expected[1]
        else:
            base_value = expected

        if isinstance(self.result, list):
            shap_values = self.result[1]
        else:
            shap_values = self.result

        force_plot(
            base_value=base_value,
            shap_values=shap_values,
            features=self.new_observation.values,
            feature_names=self.new_observation.columns,
            matplotlib=True,
            **kwargs
        )
    elif self.type == 'model_parts':
        summary_plot(
            shap_values=self.result,
            features=self.new_observation,
            **kwargs
        )