Module dalex.wrappers
Expand source code Browse git
from ._shap.object import ShapWrapper
# Names exported by ``from dalex.wrappers import *``.
__all__ = ["ShapWrapper"]
Classes
class ShapWrapper (type)-
Explanation wrapper for the `shap` package.
This object uses the `shap` package to create the model explanation. See https://github.com/slundberg/shap
Parameters
type : {'predict_parts', 'model_parts'} — Method type for calculations.
Attributes
result : list or numpy.ndarray — Calculated shap values for `new_observation` data.
shap_explainer : {shap.TreeExplainer, shap.DeepExplainer, shap.GradientExplainer, shap.LinearExplainer, shap.KernelExplainer} — Explainer object from the `shap` package.
shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer', 'LinearExplainer', 'KernelExplainer'} — String name of the Explainer class.
new_observation : pandas.Series or pandas.DataFrame — Observations for which the shap values will be calculated (later stored in `result`).
type : {'predict_parts', 'model_parts'} — Method type for calculations.
Notes
- https://github.com/slundberg/shap
Expand source code Browse git
class ShapWrapper(Explanation):
    """Explanation wrapper for the `shap` package

    This object uses the shap package to create the model explanation.
    See https://github.com/slundberg/shap

    Parameters
    ----------
    type : {'predict_parts', 'model_parts'}
        Method type for calculations.

    Attributes
    ----------
    result : list or numpy.ndarray
        Calculated shap values for `new_observation` data.
    shap_explainer : {shap.TreeExplainer, shap.DeepExplainer, shap.GradientExplainer,
                      shap.LinearExplainer, shap.KernelExplainer}
        Explainer object from the `shap` package.
    shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
                           'LinearExplainer', 'KernelExplainer'}
        String name of the Explainer class.
    new_observation : pandas.Series or pandas.DataFrame
        Observations for which the shap values will be calculated
        (later stored in `result`).
    type : {'predict_parts', 'model_parts'}
        Method type for calculations.

    Notes
    ----------
    - https://github.com/slundberg/shap
    """

    def __init__(self, type):
        self.shap_explainer = None
        self.type = type
        self.result = None
        self.new_observation = None
        self.shap_explainer_type = None

    def fit(self, explainer, new_observation, shap_explainer_type=None, **kwargs):
        """Calculate the result of explanation

        Fit method makes calculations in place and changes the attributes.

        Parameters
        -----------
        explainer : Explainer object
            Model wrapper created using the Explainer class.
        new_observation : pd.Series or np.ndarray
            An observation for which a prediction needs to be explained.
            When `type` is `'predict_parts'` it is first passed through
            `checks.check_new_observation_predict_parts`; otherwise it is
            used as-is and must expose `.values` (e.g. a pandas DataFrame).
        shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
                               'LinearExplainer', 'KernelExplainer'}
            String name of the Explainer class (default is `None`, which
            automatically chooses an Explainer to use).
        kwargs : dict
            Keyword parameters passed to the `shap_values` method of the
            shap explainer.

        Returns
        -----------
        None

        Raises
        -----------
        ValueError
            If the resolved `shap_explainer_type` is not one of the
            supported Explainer names.
        """
        from shap import TreeExplainer, DeepExplainer, GradientExplainer, LinearExplainer, KernelExplainer

        checks.check_compatibility(explainer)
        shap_explainer_type = checks.check_shap_explainer_type(shap_explainer_type, explainer.model)

        if self.type == 'predict_parts':
            new_observation = checks.check_new_observation_predict_parts(new_observation, explainer)

        if shap_explainer_type == "TreeExplainer":
            try:
                self.shap_explainer = TreeExplainer(explainer.model, explainer.data.values)
            except Exception:
                # If building the TreeExplainer with background data fails,
                # retry without it.
                # https://github.com/ModelOriented/DALEX/issues/371
                self.shap_explainer = TreeExplainer(explainer.model)
        elif shap_explainer_type == "DeepExplainer":
            self.shap_explainer = DeepExplainer(explainer.model, explainer.data.values)
        elif shap_explainer_type == "GradientExplainer":
            self.shap_explainer = GradientExplainer(explainer.model, explainer.data.values)
        elif shap_explainer_type == "LinearExplainer":
            self.shap_explainer = LinearExplainer(explainer.model, explainer.data.values)
        elif shap_explainer_type == "KernelExplainer":
            # KernelExplainer is given a prediction function rather than the raw model.
            self.shap_explainer = KernelExplainer(
                lambda x: explainer.predict(x), explainer.data.values
            )
        else:
            # Without this branch an unmatched name would leave
            # ``self.shap_explainer`` as None, or silently reuse the
            # explainer built by a previous ``fit`` call.
            raise ValueError(
                f"Unsupported shap_explainer_type: {shap_explainer_type!r}"
            )

        self.result = self.shap_explainer.shap_values(new_observation.values, **kwargs)
        self.new_observation = new_observation
        self.shap_explainer_type = shap_explainer_type

    def plot(self, **kwargs):
        """Plot the Shap Wrapper

        Parameters
        ----------
        kwargs : dict
            Keyword arguments passed to one of the:
                - shap.force_plot when type is `'predict_parts'`
                - shap.summary_plot when type is `'model_parts'`
            Exceptions are: `base_value`, `shap_values`, `features` and `feature_names`.

        Returns
        -----------
        None

        Notes
        --------
        - https://github.com/slundberg/shap
        """
        from shap import force_plot, summary_plot

        if self.type == 'predict_parts':
            if isinstance(self.shap_explainer.expected_value, (np.ndarray, list)):
                # Multi-output explainer: the output at index 1 is plotted
                # (presumably the positive class of a binary classifier — TODO confirm).
                base_value = self.shap_explainer.expected_value[1]
            else:
                base_value = self.shap_explainer.expected_value
            shap_values = self.result[1] if isinstance(self.result, list) else self.result
            force_plot(base_value=base_value, shap_values=shap_values,
                       features=self.new_observation.values,
                       feature_names=self.new_observation.columns,
                       matplotlib=True, **kwargs)
        elif self.type == 'model_parts':
            summary_plot(shap_values=self.result, features=self.new_observation, **kwargs)
- dalex._explanation.Explanation
- abc.ABC
Methods
def fit(self, explainer, new_observation, shap_explainer_type=None, **kwargs)-
Calculate the result of explanation
Fit method makes calculations in place and changes the attributes.
Parameters
explainer : Explainer object — Model wrapper created using the Explainer class.
new_observation : pd.Series or np.ndarray — An observation for which a prediction needs to be explained.
shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer', 'LinearExplainer', 'KernelExplainer'} — String name of the Explainer class (default is `None`, which automatically chooses an Explainer to use).
kwargs : dict — Keyword parameters passed to the `shap_values` method of the shap explainer.
Returns
None
Expand source code Browse git
def fit(self, explainer, new_observation, shap_explainer_type=None, **kwargs):
    """Calculate the result of explanation

    Fit method makes calculations in place and changes the attributes.

    Parameters
    -----------
    explainer : Explainer object
        Model wrapper created using the Explainer class.
    new_observation : pd.Series or np.ndarray
        An observation for which a prediction needs to be explained.
        When `type` is `'predict_parts'` it is first passed through
        `checks.check_new_observation_predict_parts`; otherwise it is
        used as-is and must expose `.values` (e.g. a pandas DataFrame).
    shap_explainer_type : {'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
                           'LinearExplainer', 'KernelExplainer'}
        String name of the Explainer class (default is `None`, which
        automatically chooses an Explainer to use).
    kwargs : dict
        Keyword parameters passed to the `shap_values` method of the
        shap explainer.

    Returns
    -----------
    None

    Raises
    -----------
    ValueError
        If the resolved `shap_explainer_type` is not one of the
        supported Explainer names.
    """
    from shap import TreeExplainer, DeepExplainer, GradientExplainer, LinearExplainer, KernelExplainer

    checks.check_compatibility(explainer)
    shap_explainer_type = checks.check_shap_explainer_type(shap_explainer_type, explainer.model)

    if self.type == 'predict_parts':
        new_observation = checks.check_new_observation_predict_parts(new_observation, explainer)

    if shap_explainer_type == "TreeExplainer":
        try:
            self.shap_explainer = TreeExplainer(explainer.model, explainer.data.values)
        except Exception:
            # If building the TreeExplainer with background data fails,
            # retry without it.
            # https://github.com/ModelOriented/DALEX/issues/371
            self.shap_explainer = TreeExplainer(explainer.model)
    elif shap_explainer_type == "DeepExplainer":
        self.shap_explainer = DeepExplainer(explainer.model, explainer.data.values)
    elif shap_explainer_type == "GradientExplainer":
        self.shap_explainer = GradientExplainer(explainer.model, explainer.data.values)
    elif shap_explainer_type == "LinearExplainer":
        self.shap_explainer = LinearExplainer(explainer.model, explainer.data.values)
    elif shap_explainer_type == "KernelExplainer":
        # KernelExplainer is given a prediction function rather than the raw model.
        self.shap_explainer = KernelExplainer(
            lambda x: explainer.predict(x), explainer.data.values
        )
    else:
        # Without this branch an unmatched name would leave
        # ``self.shap_explainer`` as None, or silently reuse the
        # explainer built by a previous ``fit`` call.
        raise ValueError(
            f"Unsupported shap_explainer_type: {shap_explainer_type!r}"
        )

    self.result = self.shap_explainer.shap_values(new_observation.values, **kwargs)
    self.new_observation = new_observation
    self.shap_explainer_type = shap_explainer_type
**kwargs)-
Plot the Shap Wrapper
Parameters
kwargs : dict — Keyword arguments passed to one of the:
- `shap.force_plot` when type is `'predict_parts'`
- `shap.summary_plot` when type is `'model_parts'`
Exceptions are: `base_value`, `shap_values`, `features` and `feature_names`.
Returns
None
Notes
- https://github.com/slundberg/shap
Expand source code Browse git
def plot(self, **kwargs):
    """Plot the Shap Wrapper

    Parameters
    ----------
    kwargs : dict
        Extra keyword arguments forwarded to the underlying shap plot:
            - shap.force_plot when type is `'predict_parts'`
            - shap.summary_plot when type is `'model_parts'`
        The wrapper itself supplies `base_value`, `shap_values`, `features`
        and `feature_names`, so those must not be given here.

    Returns
    -----------
    None

    Notes
    --------
    - https://github.com/slundberg/shap
    """
    from shap import force_plot, summary_plot

    method_type = self.type

    if method_type == 'model_parts':
        summary_plot(shap_values=self.result, features=self.new_observation, **kwargs)
    elif method_type == 'predict_parts':
        expected = self.shap_explainer.expected_value
        # For multi-output explainers (list/array of expected values and a list
        # of shap arrays) the output at index 1 is the one plotted.
        if isinstance(expected, (np.ndarray, list)):
            base_value = expected[1]
        else:
            base_value = expected

        if isinstance(self.result, list):
            shap_values = self.result[1]
        else:
            shap_values = self.result

        observation = self.new_observation
        force_plot(
            base_value=base_value,
            shap_values=shap_values,
            features=observation.values,
            feature_names=observation.columns,
            matplotlib=True,
            **kwargs,
        )