Package dalex

dalex: Responsible Machine Learning in Python


Overview

An unverified black-box model is the path to failure. Opaqueness leads to distrust. Distrust leads to ignorance. Ignorance leads to rejection.

The dalex package X-rays any model and helps to explore, explain, and understand how complex models work. The main Explainer object creates a wrapper around a predictive model. Wrapped models can then be explored and compared with a collection of model-level and predict-level explanations. Moreover, fairness methods and interactive exploration dashboards are available to the user.

The philosophy behind dalex explanations is described in the Explanatory Model Analysis e-book.

Installation

The dalex package is available on PyPI:

pip install dalex -U

Examples

Plots

This package uses plotly to render the plots.
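
A minimal sketch of the typical workflow (the model choice and the restriction to numeric columns are illustrative; the titanic data ships with dalex):

import dalex as dx
from sklearn.ensemble import RandomForestClassifier

# load the titanic data bundled with dalex
data = dx.datasets.load_titanic()
X = data[['age', 'fare', 'sibsp', 'parch']]  # numeric columns only, for brevity
y = data.survived

model = RandomForestClassifier(random_state=0).fit(X, y)
exp = dx.Explainer(model, X, y, label='titanic_rf')

exp.model_parts().plot()               # model-level: permutation feature importance
exp.predict_parts(X.iloc[[0]]).plot()  # predict-level: break-down for a single row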

Citation

If you use dalex, please cite our paper:

@article{dalex,
  title={dalex: Responsible Machine Learning with Interactive
         Explainability and Fairness in Python},
  author={Hubert Baniecki and Wojciech Kretowicz and Piotr Piatyszek
          and Jakub Wisniewski and Przemyslaw Biecek},
  year={2020},
  eprint={2012.14406},
  archivePrefix={arXiv},
  url={https://arxiv.org/abs/2012.14406}
}

Developer

Class diagram

Folder structure


Expand source code
"""
.. include:: ./documentation.md
"""


from . import datasets
from ._explainer.object import Explainer
from .arena.object import Arena

__version__ = '1.0.0'

__all__ = [
    "Arena",
    "datasets",
    "Explainer",
    "fairness"
]

# specify autocompletion in IPython
# see comment: https://github.com/ska-telescope/katpoint/commit/ed7e8b9e389ee035073c62c2394975fe71031f88
# __dir__ docs (Python 3.7!): https://docs.python.org/3.7/library/functions.html#dir

def __dir__():
    """IPython tab completion seems to respect this."""
    return __all__ +  ['__all__', '__builtins__', '__cached__', '__doc__', '__file__',
                       '__loader__', '__name__', '__package__', '__path__', '__spec__',
                       '__version__']

Sub-modules

dalex.arena
dalex.datasets
dalex.fairness
dalex.model_explanations
dalex.predict_explanations
dalex.wrappers

Classes

class Arena (precalculate=False, enable_attributes=True, enable_custom_params=True)

Creates Arena object

This class should be used to create a Python connector for the Arena dashboard. An initialized object can work in both static and live mode. Use push_* methods to add your models, observations and datasets.

Parameters

precalculate : bool
Enables precalculating plots after each push_* method call.
enable_attributes : bool
Enables attributes of observations and variables. Attributes are required to display observation details in Arena, but they also increase the generated file size.
enable_custom_params : bool
Enables modifying observations in the dashboard. It requires attributes and works only in live mode.

Attributes

models : list of ModelParam objects
List of pushed models encapsulated in ModelParam class
observations : list of ObservationParam objects
List of pushed observations encapsulated in ObservationParam class
datasets : list of DatasetParam objects
List of pushed datasets encapsulated in DatasetParam class
variables_cache : list of VariableParam objects
Cached list of VariableParam objects generated using pushed models and datasets
server_thread : threading.Thread
Thread of the running server, or None if no server is running
precalculate : bool
whether plots should be precalculated
enable_attributes : bool
whether attributes are enabled
enable_custom_params : bool
whether modifying observations is enabled
timestamp : float
timestamp of the last modification
cache : list of PlotContainer objects
List of already calculated plots
mutex : _thread.lock
Mutex for params and cache
plots : list of classes extending PlotContainer
List of enabled plots
options : dict
Options for plots

Notes

For a tutorial, see https://arena.drwhy.ai/docs/guide/first-datasource
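
A minimal usage sketch (exp is assumed to be a fitted dalex.Explainer with data X, as in the package-level example above; observation labels are illustrative):

import dalex as dx

arena = dx.Arena()
arena.push_model(exp)          # any dalex.Explainer

obs = X.iloc[:3].copy()
obs.index = ['a', 'b', 'c']    # row names become observation labels
arena.push_observations(obs)

arena.run_server()             # live mode: prints a link to the dashboard
# ... explore in the browser ...
arena.stop_server()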

Expand source code
class Arena:
    """ Creates Arena object

    This class should be used to create a Python connector for the Arena dashboard.
    An initialized object can work in both static and live mode. Use `push_*` methods
    to add your models, observations and datasets.

    Parameters
    ----------
    precalculate : bool
        Enables precalculating plots after each `push_*` method call.
    enable_attributes : bool
        Enables attributes of observations and variables. Attributes are required to
        display observation details in Arena, but they also increase the generated
        file size.
    enable_custom_params : bool
        Enables modifying observations in the dashboard. It requires attributes and
        works only in live mode.

    Attributes
    --------
    models : list of ModelParam objects
        List of pushed models encapsulated in ModelParam class
    observations : list of ObservationParam objects
        List of pushed observations encapsulated in ObservationParam class
    datasets : list of DatasetParam objects
        List of pushed datasets encapsulated in DatasetParam class
    variables_cache : list of VariableParam objects
        Cached list of VariableParam objects generated using pushed models and datasets
    server_thread : threading.Thread
        Thread of the running server, or None if no server is running
    precalculate : bool
        whether plots should be precalculated
    enable_attributes : bool
        whether attributes are enabled
    enable_custom_params : bool
        whether modifying observations is enabled
    timestamp : float
        timestamp of the last modification
    cache : list of PlotContainer objects
        List of already calculated plots
    mutex : _thread.lock
        Mutex for params and cache
    plots : list of classes extending PlotContainer
        List of enabled plots
    options : dict
        Options for plots

    Notes
    --------
    For a tutorial, see https://arena.drwhy.ai/docs/guide/first-datasource

    """
    def __init__(self, precalculate=False, enable_attributes=True, enable_custom_params=True):
        self.models = []
        self.observations = []
        self.datasets = []
        self.variables_cache = []
        self.server_thread = None
        self.precalculate = bool(precalculate)
        self.enable_attributes = bool(enable_attributes)
        self.enable_custom_params = bool(enable_custom_params)
        self.timestamp = datetime.timestamp(datetime.now())
        self.cache = []
        self.mutex = threading.Lock()
        self.plots = [
            ShapleyValuesContainer,
            FeatureImportanceContainer,
            PartialDependenceContainer,
            AccumulatedDependenceContainer,
            CeterisParibusContainer,
            BreakDownContainer,
            MetricsContainer,
            ROCContainer,
            FairnessCheckContainer
        ]
        self.options = {}
        for plot in self.plots:
            options = {}
            for o in plot.options.keys():
                options[o] = plot.options.get(o).get('default')
            self.options[plot.info.get('plotType')] = options

    def get_supported_plots(self):
        """Returns plots classes that can produce at least one valid chart for this Arena.

        Returns
        -----------
        List of classes extending PlotContainer
        """
        return [plot for plot in self.plots if plot.test_arena(self)]

    def run_server(self,
                   host='127.0.0.1',
                   port=8181,
                   append_data=False,
                   arena_url='https://arena.drwhy.ai/',
                   disable_logs=True):
        """Starts server for live mode of Arena

        Parameters
        -----------
        host : str
            IP address or hostname for the server
        port : int
            port number for the server
        append_data : bool
            whether the generated link should append data to an already open Arena window
        arena_url : str
            URL of the Arena dashboard
        disable_logs : bool
            whether logs should be muted

        Notes
        --------
        Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

        Returns
        -----------
        None. The link to Arena is printed to standard output.
        """
        if self.server_thread:
            raise Exception('Server is already running. To stop it, use arena.stop_server().')
        global_check_import('flask')
        global_check_import('flask_cors')
        global_check_import('requests')
        self.server_thread = threading.Thread(target=start_server, args=(self, host, port, disable_logs))
        self.server_thread.start()
        if append_data:
            print(arena_url + '?append=http://' + host + ':' + str(port) + '/')
        else:
            print(arena_url + '?data=http://' + host + ':' + str(port) + '/')

    def stop_server(self):
        """Stops running server"""
        if not self.server_thread:
            raise Exception('Server is not running')
        self._stop_server()
        self.server_thread.join()
        self.server_thread = None

    def update_timestamp(self):
        """Updates timestamp

        Notes
        -------
        This function must be called while holding the mutex
        """
        now = datetime.now()
        self.timestamp = datetime.timestamp(now)

    def clear_cache(self, plot_type=None):
        """Clears cache

        Parameters
        -----------
        plot_type : str or None
            If None, the whole cache is cleared. Otherwise only plots with the
            provided plot_type are removed.

        Notes
        -------
        This function must be called while holding the mutex
        """
        if plot_type is None:
            self.cache = []
        else:
            self.cache = list(filter(lambda p: p.plot_type != plot_type, self.cache))
        self.update_timestamp()

    def find_in_cache(self, plot_type, params):
        """Function searches for cached plot

        Parameters
        -----------
        plot_type : str
            Value of plot_type field, that requested plot must have
        params : dict
            Keys of this dict are params types (model, observation, variable, dataset)
            and values are corresponding params labels. Requested plot must have equal
            params field.

        Returns
        --------
        PlotContainer or None
        """

        _filter = lambda p: p.plot_type == plot_type and params == p.params
        with self.mutex:
            return next(filter(_filter, self.cache), None)

    def put_to_cache(self, plot_container):
        """Puts new plot to cache

        Parameters
        -----------
        plot_container : PlotContainer
        """
        if not isinstance(plot_container, PlotContainer):
            raise Exception('Invalid plot container')
        with self.mutex:
            self.cache.append(plot_container)

    def fill_cache(self, fixed_params={}):
        """Generates all available plots and cache them

        This function tries to generate all plots that are not cached already and
        put them into the cache. The range of generated plots can be narrowed using `fixed_params`.

        Parameters
        -----------
        fixed_params : dict
            This dict specifies which plots should be generated. Only those with
            all keys from `fixed_params` present and having the same value will be
            calculated.
        """
        if not isinstance(fixed_params, dict):
            raise Exception('Params argument must be a dict')
        for plot_class in self.get_supported_plots():
            required_params = plot_class.info.get('requiredParams')
            # Test if all params fixed by user are used in this plot. If not, then skip it.
            # This list contains fixed params' types, that are not required by plot.
            # Loop will be skipped if this list is not empty.
            if len([k for k in fixed_params.keys() if not k in required_params]) > 0:
                continue
            available_params = self.get_available_params()
            iteration_pools = map(lambda p: available_params.get(p) if fixed_params.get(p) is None else [fixed_params.get(p)], required_params)
            combinations = [[]]
            for pool in iteration_pools:
                combinations = [x + [y] for x in combinations for y in pool]
            for params_values in combinations:
                params = dict(zip(required_params, params_values))
                self.get_plot(plot_type=plot_class.info.get('plotType'), params_values=params)

    def push_model(self, explainer, precalculate=None):
        """Adds model to Arena

        This method encapsulates the explainer in a ModelParam object and
        appends it to the models field. When precalculation is enabled,
        it triggers filling the cache.

        Parameters
        -----------
        explainer : dalex.Explainer
            Explainer created using dalex package
        precalculate : bool or None
            Overrides constructor `precalculate` parameter when it is not None.
            If true, then only plots using this model will be precalculated.
        """
        if not isinstance(explainer, Explainer):
            raise Exception('Invalid Explainer argument')
        if explainer.label in self.list_params('model'):
            raise Exception('Explainer with the same label was already added')
        precalculate = self.precalculate if precalculate is None else bool(precalculate)
        param = ModelParam(explainer)
        with self.mutex:
            self.update_timestamp()
            self.models.append(param)
            self.variables_cache = []
        if precalculate:
            self.fill_cache({'model': param})

    def push_observations(self, observations, precalculate=None):
        """Adds observations to Arena

        Pushed observations will be used for local explanations. The function
        creates an ObservationParam object for each row of the pushed data frame.
        The label for each observation is taken from the row name. When precalculation
        is enabled, it triggers filling the cache.

        Parameters
        -----------
        observations : pandas.DataFrame
            Data frame of observations to be explained using instance level plots.
            Label for each observation is taken from row name.
        precalculate : bool or None
            Overrides constructor `precalculate` parameter when it is not None.
            If true, then only plots using these observations will be precalculated.
        """
        if not isinstance(observations, DataFrame):
            raise Exception('Observations argument is not a pandas DataFrame')
        if len(observations.index.names) != 1:
            raise Exception('Observations argument needs to have only one index')
        if not observations.index.is_unique:
            raise Exception('Observations argument needs to have unique indexes')
        precalculate = self.precalculate if precalculate is None else bool(precalculate)
        old_observations = self.list_params('observation')
        observations = observations.set_index(observations.index.astype(str))
        params_objects = []
        for x in observations.index:
            if x in old_observations:
                raise Exception('Indexes of observations need to be unique across all observations')
            params_objects.append(ObservationParam(dataset=observations, index=x))
        with self.mutex:
            self.update_timestamp()
            self.observations.extend(params_objects)
        if precalculate:
            for obs in params_objects:
                self.fill_cache({'observation': obs})

    def push_dataset(self, dataset, target, label, precalculate=None):
        """Adds dataset to Arena

        The pushed dataset will be visualised using exploratory data analysis plots.
        The function creates a DatasetParam object with the specified label and target name.
        When precalculation is enabled, it triggers filling the cache.

        Parameters
        -----------
        dataset : pandas.DataFrame
            Data frame to be visualised using EDA plots. This
            dataset should contain target variable.
        target : str
            Name of target column
        label : str
            Label for this dataset
        precalculate : bool or None
            Overrides constructor `precalculate` parameter when it is not None.
            If true, then only plots using this dataset will be precalculated.
        """
        if not isinstance(dataset, DataFrame):
            raise Exception('Dataset argument is not a pandas DataFrame')
        if len(dataset.columns.names) != 1:
            raise Exception('Dataset argument needs to have only one level of column names')
        precalculate = self.precalculate if precalculate is None else bool(precalculate)
        target = str(target)
        if target not in dataset.columns:
            raise Exception('Target is not a column from dataset')
        if (not isinstance(label, str)) or (len(label) == 0):
            raise Exception('Label needs to be at least one letter')
        if label in self.list_params('dataset'):
            raise Exception('Labels need to be unique')
        param = DatasetParam(dataset=dataset, label=label, target=target)
        with self.mutex:
            self.update_timestamp()
            self.datasets.append(param)
            self.variables_cache = []
        if precalculate:
            self.fill_cache({'dataset': param})

    def get_params(self, param_type):
        """Returns list of available params

        Parameters
        -----------
        param_type : str
            One of ['model', 'variable', 'observation', 'dataset']. Params of this type
            will be returned

        Notes
        --------
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        --------
        List of Param objects
        """
        if param_type == 'observation':
            with self.mutex:
                return self.observations
        elif param_type == 'variable':
            with self.mutex:
                if not self.variables_cache:
                    # Extract column names from every dataset in self.dataset list and flatten it
                    result_datasets = [col for dataset in self.datasets for col in dataset.variables]
                    # Extract column names from every model in self.models list and flatten it
                    result_explainers = [col for model in self.models for col in model.variables]
                    result_str = np.unique(result_datasets + result_explainers).tolist()
                    self.variables_cache = [VariableParam(x) for x in result_str]
                    if self.enable_attributes:
                        for var in self.variables_cache:
                            try:
                                for dataset in self.datasets:
                                    if var.variable in dataset.variables:
                                        var.update_attributes(dataset.dataset[var.variable])
                                for model in self.models:
                                    if var.variable in model.variables:
                                        var.update_attributes(model.explainer.data[var.variable])
                            except:
                                var.clear_attributes()
                return self.variables_cache
        elif param_type == 'model':
            with self.mutex:
                return self.models
        elif param_type == 'dataset':
            with self.mutex:
                return self.datasets
        else:
            raise Exception('Invalid param type')

    def list_params(self, param_type):
        """Returns list of available params's labels

        Parameters
        -----------
        param_type : str
            One of ['model', 'variable', 'observation', 'dataset']. Labels of params
            of this type will be returned

        Notes
        --------
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        --------
        List of str
        """
        return [x.get_label() for x in self.get_params(param_type)]

    def get_available_params(self):
        """Returns dict containing available params of all types

        This method collects the result of the `get_params` method for each param type into
        a dict. Keys are param types and values are lists of Param objects.

        Notes
        --------
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        --------
        dict
        """
        result = {}
        for param_type in ['model', 'observation', 'variable', 'dataset']:
            result[param_type] = self.get_params(param_type)
        return result

    def list_available_params(self):
        """Returns dict containing labels of available params of all types

        This method collects the result of `list_params` for each param type into
        a dict. Keys are param types and values are lists of labels.

        Notes
        --------
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        --------
        dict
        """
        result = {}
        for param_type in ['model', 'observation', 'variable', 'dataset']:
            result[param_type] = self.list_params(param_type)
        return result

    def find_param_value(self, param_type, param_label):
        """Searches for Param object with specified label

        Parameters
        -----------
        param_type : str
            One of ['model', 'variable', 'observation', 'dataset'].
        param_label : str
            Label of searched param

        Notes
        --------
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        --------
        Param or None
        """
        if param_label is None or not isinstance(param_label, str):
            return None
        return next((x for x in self.get_params(param_type) if x.get_label() == param_label), None)

    def print_options(self, plot_type=None):
        """Prints available options for plots

        Parameters
        -----------
        plot_type : str or None
            When not None, then only options for plots with this plot_type will
            be printed.

        Notes
        --------
        List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level
        """

        plot = next((x for x in self.plots if x.info.get('plotType') == plot_type), None)
        if plot_type is None or plot is None:
            for plot in self.plots:
                self.print_options(plot.info.get('plotType'))
            return
        print('\n\033[1m' + plot.info.get('plotType') + '\033[0m')
        print('---------------------------------')
        for o in plot.options.keys():
            option = plot.options.get(o)
            value = self.options.get(plot_type).get(o)
            print(o + ': ' + str(value) + '   #' + option.get('desc'))

    def get_option(self, plot_type, option):
        """Returns value of specified option

        Parameters
        -----------
        plot_type : str
            Type of plot the option corresponds to.
        option : str
            Name of the option

        Notes
        --------
        List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level

        Returns
        --------
        None or value of option
        """
        options = self.options.get(plot_type)
        if options is None:
            raise Exception('Invalid plot_type')
        if not option in options.keys():
            return
        with self.mutex:
            return self.options.get(plot_type).get(option)

    def set_option(self, plot_type, option, value):
        """Sets value for the plot option

        Parameters
        -----------
        plot_type : str or None
            When None, the value will be set for each plot that has
            an option named by the `option` argument. Otherwise only
            for plots of the specified type.
        option : str
            Name of the option
        value : *
            Value to be set

        Notes
        --------
        List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level
        """
        if plot_type is None:
            for plot in self.plots:
                self.set_option(plot.info.get('plotType'), option, value)
            return
        options = self.options.get(plot_type)
        if options is None:
            raise Exception('Invalid plot_type')
        if not option in options.keys():
            return
        with self.mutex:
            self.options.get(plot_type)[option] = value
            self.clear_cache(plot_type)
        if self.precalculate:
            self.fill_cache()

    def get_plot(self, plot_type, params_values, cache=True):
        """Returns plot for specified type and params

        The function searches for the plot in cache; when not present, it creates
        the requested plot and puts it into the cache.

        Parameters
        -----------
        plot_type : str
            Type of plot to be generated
        params_values : dict
            Dict for param types as keys and Param objects as values
        cache : bool
            Whether to search for the plot in cache and put the calculated plot into the cache.

        Returns
        --------
        PlotContainer
        """
        plot_class = next((c for c in self.plots if c.info.get('plotType') == plot_type), None)
        if plot_class is None:
            raise Exception('Not supported plot type')
        plot_type = plot_class.info.get('plotType')
        required_params_values = {}
        required_params_labels = {}
        for p in plot_class.info.get('requiredParams'):
            if params_values.get(p) is None:
                raise Exception('Required param is missing')
            required_params_values[p] = params_values.get(p)
            required_params_labels[p] = params_values.get(p).get_label()
        result = self.find_in_cache(plot_type, required_params_labels) if cache else None
        if result is None:
            result = plot_class(self).fit(required_params_values)
            if cache:
                self.put_to_cache(result)
        return result

    def get_params_attributes(self, param_type=None):
        """Returns attributes for all params

        When `param_type` is not None, the function returns a list of dicts. Each dict represents
        one available attribute for the specified param type. Field `name` is the attribute name
        and field `values` is the list of that attribute's values across available params.
        When `param_type` is None, the function returns a dict with a key for each param type
        whose values are the lists described above.

        Parameters
        -----------
        param_type : str or None
            One of ['model', 'variable', 'observation', 'dataset'] or None. Specifies
            attributes of which params should be returned.

        Notes
        --------
        Attributes are used for dynamically modifying observations https://arena.drwhy.ai/docs/guide/modifying-observations

        Returns
        --------
        dict or list
        """

        if param_type is None:
            obj = {}
            for p in ['model', 'observation', 'variable', 'dataset']:
                obj[p] = self.get_params_attributes(p)
            return obj
        if not self.enable_attributes:
            return []
        attrs = Param.get_param_class(param_type).list_attributes(self)
        array = []
        for attr in attrs:
            array.append({
                'name': attr,
                'values': [param.get_attributes().get(attr) for param in self.get_params(param_type)]
            })
        return array

    def get_param_attributes(self, param_type, param_label):
        """Returns attributes for one param

        The function searches for the param with the specified type and label and
        returns its attributes.

        Parameters
        -----------
        param_type : str
            One of ['model', 'variable', 'observation', 'dataset'].
        param_label : str
            Label of param

        Notes
        --------
        Attributes are used for dynamically modifying observations https://arena.drwhy.ai/docs/guide/modifying-observations

        Returns
        --------
        dict
        """

        if not self.enable_attributes:
            return {}
        param_value = self.find_param_value(param_type=param_type, param_label=param_label)
        if param_value:
            return param_value.get_attributes()
        else:
            return {}

    def save(self, filename="datasource.json"):
        """Generate all plots and saves them to JSON file

        Function generates only not cached plots.

        Parameters
        -----------
        filename : str
            Path or filename to output file

        Notes
        --------
        Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

        Returns
        --------
        None
        """
        with open(filename, 'w') as file:
            file.write(get_json(self))

    def upload(self, token=None, arena_url='https://arena.drwhy.ai/', open_browser=True):
        """Generate all plots and uploads them to GitHub Gist

        Function generates only not cached plots. If token is not provided
        then function uses OAuth to open GitHub authorization page.

        Parameters
        -----------
        token : str or None
            GitHub personal access token. If token is None, then OAuth is used.
        arena_url : str
            Address of Arena dashboard instance
        open_browser : bool
            Whether to open Arena after upload.

        Notes
        --------
        Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

        Returns
        --------
        Link to the Arena
        """
        global_check_import('requests')
        if token is None:
            global_check_import('flask')
            global_check_import('flask_cors')
            token = generate_token()
        data_url = upload_arena(self, token)
        url = arena_url + '?data=' + data_url
        if open_browser:
            webbrowser.open(url)
        return url

Methods

def clear_cache(self, plot_type=None)

Clears cache

Parameters

plot_type : str or None
If None, the whole cache is cleared. Otherwise only plots with the provided plot_type are removed.

Notes

This function must be called while holding the mutex

Expand source code
def clear_cache(self, plot_type=None):
    """Clears cache

    Parameters
    -----------
    plot_type : str or None
        If None, the whole cache is cleared. Otherwise only plots with the
        provided plot_type are removed.

    Notes
    -------
    This function must be called while holding the mutex
    """
    if plot_type is None:
        self.cache = []
    else:
        self.cache = list(filter(lambda p: p.plot_type != plot_type, self.cache))
    self.update_timestamp()
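
For example, since clear_cache assumes the caller already holds the mutex, a manual cache reset could look like this (the plot type string is illustrative):

with arena.mutex:
    arena.clear_cache()                     # drop every cached plot
with arena.mutex:
    arena.clear_cache('FeatureImportance')  # drop one plot type only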
def fill_cache(self, fixed_params={})

Generates all available plots and cache them

This function tries to generate all plots that are not cached already and put them into the cache. The range of generated plots can be narrowed using fixed_params.

Parameters

fixed_params : dict
This dict specifies which plots should be generated. Only those with all keys from fixed_params present and having the same value will be calculated.
Expand source code
def fill_cache(self, fixed_params={}):
    """Generates all available plots and cache them

    This function tries to generate all plots that are not cached already and
    put them into the cache. The range of generated plots can be narrowed using `fixed_params`.

    Parameters
    -----------
    fixed_params : dict
        This dict specifies which plots should be generated. Only those with
        all keys from `fixed_params` present and having the same value will be
        calculated.
    """
    if not isinstance(fixed_params, dict):
        raise Exception('Params argument must be a dict')
    for plot_class in self.get_supported_plots():
        required_params = plot_class.info.get('requiredParams')
        # Test if all params fixed by user are used in this plot. If not, then skip it.
        # This list contains fixed params' types, that are not required by plot.
        # Loop will be skipped if this list is not empty.
        if len([k for k in fixed_params.keys() if not k in required_params]) > 0:
            continue
        available_params = self.get_available_params()
        iteration_pools = map(lambda p: available_params.get(p) if fixed_params.get(p) is None else [fixed_params.get(p)], required_params)
        combinations = [[]]
        for pool in iteration_pools:
            combinations = [x + [y] for x in combinations for y in pool]
        for params_values in combinations:
            params = dict(zip(required_params, params_values))
            self.get_plot(plot_type=plot_class.info.get('plotType'), params_values=params)
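
For example, to precalculate only the plots that use one pushed model (values must be Param objects, not labels; the label 'titanic_rf' is illustrative):

model_param = arena.find_param_value('model', 'titanic_rf')
arena.fill_cache({'model': model_param})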
def find_in_cache(self, plot_type, params)

Searches for a cached plot

Parameters

plot_type : str
Value of the plot_type field that the requested plot must have
params : dict
Keys of this dict are param types (model, observation, variable, dataset) and values are the corresponding params' labels. The requested plot must have an equal params field.

Returns

PlotContainer or None
 
Expand source code
def find_in_cache(self, plot_type, params):
    """Function searches for cached plot

    Parameters
    -----------
    plot_type : str
        Value of plot_type field, that requested plot must have
    params : dict
        Keys of this dict are params types (model, observation, variable, dataset)
        and values are corresponding params labels. Requested plot must have equal
        params field.

    Returns
    --------
    PlotContainer or None
    """

    _filter = lambda p: p.plot_type == plot_type and params == p.params
    with self.mutex:
        return next(filter(_filter, self.cache), None)
def find_param_value(self, param_type, param_label)

Searches for Param object with specified label

Parameters

param_type : str
One of ['model', 'variable', 'observation', 'dataset'].
param_label : str
Label of searched param

Notes

Information about params https://arena.drwhy.ai/docs/guide/params

Returns

Param or None
 
Expand source code
def find_param_value(self, param_type, param_label):
    """Searches for Param object with specified label

    Parameters
    -----------
    param_type : str
        One of ['model', 'variable', 'observation', 'dataset'].
    param_label : str
        Label of searched param

    Notes
    --------
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    --------
    Param or None
    """
    if param_label is None or not isinstance(param_label, str):
        return None
    return next((x for x in self.get_params(param_type) if x.get_label() == param_label), None)
def get_available_params(self)

Returns dict containing available params of all types

This method collects the result of the get_params method for each param type into a dict. Keys are param types and values are lists of Param objects.

Notes

Information about params https://arena.drwhy.ai/docs/guide/params

Returns

dict
 
Expand source code
def get_available_params(self):
    """Returns dict containing available params of all types

    This method collects the result of the `get_params` method for each param type into
    a dict. Keys are param types and values are lists of Param objects.

    Notes
    --------
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    --------
    dict
    """
    result = {}
    for param_type in ['model', 'observation', 'variable', 'dataset']:
        result[param_type] = self.get_params(param_type)
    return result
def get_option(self, plot_type, option)

Returns value of specified option

Parameters

plot_type : str
Type of plot the option corresponds to.
option : str
Name of the option

Notes

List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level

Returns

None or value of option
 
Expand source code
def get_option(self, plot_type, option):
    """Returns value of specified option

    Parameters
    -----------
    plot_type : str
        Type of plot the option corresponds to.
    option : str
        Name of the option

    Notes
    --------
    List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level

    Returns
    --------
    None or value of option
    """
    options = self.options.get(plot_type)
    if options is None:
        raise Exception('Invalid plot_type')
    if not option in options.keys():
        return
    with self.mutex:
        return self.options.get(plot_type).get(option)
def get_param_attributes(self, param_type, param_label)

Returns attributes for one param

The function searches for the param with the specified type and label and returns its attributes.

Parameters

param_type : str
One of ['model', 'variable', 'observation', 'dataset'].
param_label : str
Label of param

Notes

Attributes are used for dynamically modifying observations https://arena.drwhy.ai/docs/guide/modifying-observations

Returns

dict
 
Expand source code
def get_param_attributes(self, param_type, param_label):
    """Returns attributes for one param

    The function searches for the param with the specified type and label and
    returns its attributes.

    Parameters
    -----------
    param_type : str
        One of ['model', 'variable', 'observation', 'dataset'].
    param_label : str
        Label of param

    Notes
    --------
    Attributes are used for dynamically modifying observations https://arena.drwhy.ai/docs/guide/modifying-observations

    Returns
    --------
    dict
    """

    if not self.enable_attributes:
        return {}
    param_value = self.find_param_value(param_type=param_type, param_label=param_label)
    if param_value:
        return param_value.get_attributes()
    else:
        return {}
def get_params(self, param_type)

Returns list of available params

Parameters

param_type : str
One of ['model', 'variable', 'observation', 'dataset']. Params of this type will be returned

Notes

Information about params https://arena.drwhy.ai/docs/guide/params

Returns

List of Param objects
 
Expand source code
def get_params(self, param_type):
    """Returns list of available params

    Parameters
    -----------
    param_type : str
        One of ['model', 'variable', 'observation', 'dataset']. Params of this type
        will be returned

    Notes
    --------
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    --------
    List of Param objects
    """
    if param_type == 'observation':
        with self.mutex:
            return self.observations
    elif param_type == 'variable':
        with self.mutex:
            if not self.variables_cache:
                # Extract column names from every dataset in self.dataset list and flatten it
                result_datasets = [col for dataset in self.datasets for col in dataset.variables]
                # Extract column names from every model in self.models list and flatten it
                result_explainers = [col for model in self.models for col in model.variables]
                result_str = np.unique(result_datasets + result_explainers).tolist()
                self.variables_cache = [VariableParam(x) for x in result_str]
                if self.enable_attributes:
                    for var in self.variables_cache:
                        try:
                            for dataset in self.datasets:
                                if var.variable in dataset.variables:
                                    var.update_attributes(dataset.dataset[var.variable])
                            for model in self.models:
                                if var.variable in model.variables:
                                    var.update_attributes(model.explainer.data[var.variable])
                        except:
                            var.clear_attributes()
            return self.variables_cache
    elif param_type == 'model':
        with self.mutex:
            return self.models
    elif param_type == 'dataset':
        with self.mutex:
            return self.datasets
    else:
        raise Exception('Invalid param type')
def get_params_attributes(self, param_type=None)

Returns attributes for all params

When param_type is not None, the function returns a list of dicts. Each dict represents one available attribute for the specified param type. Field name is the attribute name and field values is the list of that attribute's values across available params. When param_type is None, the function returns a dict with a key for each param type whose values are the lists described above.

Parameters

param_type : str or None
One of ['model', 'variable', 'observation', 'dataset'] or None. Specifies attributes of which params should be returned.

Notes

Attributes are used for dynamically modifying observations https://arena.drwhy.ai/docs/guide/modifying-observations

Returns

dict or list
 
Expand source code
def get_params_attributes(self, param_type=None):
    """Returns attributes for all params

    When `param_type` is not None, the function returns a list of dicts. Each dict represents
    one available attribute for the specified param type. Field `name` is the attribute name
    and field `values` is the list of that attribute's values across available params.
    When `param_type` is None, the function returns a dict with a key for each param type
    whose values are the lists described above.

    Parameters
    -----------
    param_type : str or None
        One of ['model', 'variable', 'observation', 'dataset'] or None. Specifies
        attributes of which params should be returned.

    Notes
    --------
    Attributes are used for dynamically modifying observations https://arena.drwhy.ai/docs/guide/modifying-observations

    Returns
    --------
    dict or list
    """

    if param_type is None:
        obj = {}
        for p in ['model', 'observation', 'variable', 'dataset']:
            obj[p] = self.get_params_attributes(p)
        return obj
    if not self.enable_attributes:
        return []
    attrs = Param.get_param_class(param_type).list_attributes(self)
    array = []
    for attr in attrs:
        array.append({
            'name': attr,
            'values': [param.get_attributes().get(attr) for param in self.get_params(param_type)]
        })
    return array
def get_plot(self, plot_type, params_values, cache=True)

Returns plot for specified type and params

The function searches for the plot in cache; when not present, it creates the requested plot and puts it into the cache.

Parameters

plot_type : str
Type of plot to be generated
params_values : dict
Dict for param types as keys and Param objects as values
cache : bool
Whether to search for the plot in cache and put the calculated plot into the cache.

Returns

PlotContainer
 
Expand source code
def get_plot(self, plot_type, params_values, cache=True):
    """Returns plot for specified type and params

    The function searches for the plot in cache; when not present, it creates
    the requested plot and puts it into the cache.

    Parameters
    -----------
    plot_type : str
        Type of plot to be generated
    params_values : dict
        Dict for param types as keys and Param objects as values
    cache : bool
        Whether to search for the plot in cache and put the calculated plot into the cache.

    Returns
    --------
    PlotContainer
    """
    plot_class = next((c for c in self.plots if c.info.get('plotType') == plot_type), None)
    if plot_class is None:
        raise Exception('Not supported plot type')
    plot_type = plot_class.info.get('plotType')
    required_params_values = {}
    required_params_labels = {}
    for p in plot_class.info.get('requiredParams'):
        if params_values.get(p) is None:
            raise Exception('Required param is missing')
        required_params_values[p] = params_values.get(p)
        required_params_labels[p] = params_values.get(p).get_label()
    result = self.find_in_cache(plot_type, required_params_labels) if cache else None
    if result is None:
        result = plot_class(self).fit(required_params_values)
        if cache:
            self.put_to_cache(result)
    return result
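
For example, fetching a single plot by type and params could look like this (the plot type, labels, and required param types are illustrative; the actual required params come from each plot class's requiredParams):

params = {
    'model': arena.find_param_value('model', 'titanic_rf'),
    'variable': arena.find_param_value('variable', 'age'),
}
plot = arena.get_plot('PartialDependence', params)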
def get_supported_plots(self)

Returns plot classes that can produce at least one valid chart for this Arena.

Returns

List of classes extending PlotContainer
 
Expand source code
def get_supported_plots(self):
    """Returns plots classes that can produce at least one valid chart for this Arena.

    Returns
    -----------
    List of classes extending PlotContainer
    """
    return [plot for plot in self.plots if plot.test_arena(self)]
def list_available_params(self)

Returns dict containing labels of available params of all types

This method collects the result of list_params for each param type into a dict. Keys are param types and values are lists of labels.

Notes

Information about params https://arena.drwhy.ai/docs/guide/params

Returns

dict
 
Expand source code
def list_available_params(self):
    """Returns dict containing labels of available params of all types

    This method collects the result of `list_params` for each param type into
    a dict. Keys are param types and values are lists of labels.

    Notes
    --------
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    --------
    dict
    """
    result = {}
    for param_type in ['model', 'observation', 'variable', 'dataset']:
        result[param_type] = self.list_params(param_type)
    return result
def list_params(self, param_type)

Returns list of available params' labels

Parameters

param_type : str
One of ['model', 'variable', 'observation', 'dataset']. Labels of params of this type will be returned

Notes

Information about params https://arena.drwhy.ai/docs/guide/params

Returns

List of str
 
Expand source code
def list_params(self, param_type):
    """Returns list of available params's labels

    Parameters
    -----------
    param_type : str
        One of ['model', 'variable', 'observation', 'dataset']. Labels of params
        of this type will be returned

    Notes
    --------
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    --------
    List of str
    """
    return [x.get_label() for x in self.get_params(param_type)]
def print_options(self, plot_type=None)

Prints available options for plots

Parameters

plot_type : str or None
When not None, then only options for plots with this plot_type will be printed.

Notes

List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level

Expand source code
def print_options(self, plot_type=None):
    """Prints available options for plots

    Parameters
    -----------
    plot_type : str or None
        When not None, then only options for plots with this plot_type will
        be printed.

    Notes
    --------
    List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level
    """

    plot = next((x for x in self.plots if x.info.get('plotType') == plot_type), None)
    if plot_type is None or plot is None:
        for plot in self.plots:
            self.print_options(plot.info.get('plotType'))
        return
    print('\n\033[1m' + plot.info.get('plotType') + '\033[0m')
    print('---------------------------------')
    for o in plot.options.keys():
        option = plot.options.get(o)
        value = self.options.get(plot_type).get(o)
        print(o + ': ' + str(value) + '   #' + option.get('desc'))
def push_dataset(self, dataset, target, label, precalculate=None)

Adds dataset to Arena

The pushed dataset will be visualised using exploratory data analysis plots. The function creates a DatasetParam object with the specified label and target name. When precalculation is enabled, it triggers filling the cache.

Parameters

dataset : pandas.DataFrame
Data frame to be visualised using EDA plots. This dataset should contain target variable.
target : str
Name of target column
label : str
Label for this dataset
precalculate : bool or None
Overrides constructor precalculate parameter when it is not None. If true, then only plots using this dataset will be precalculated.
Expand source code
def push_dataset(self, dataset, target, label, precalculate=None):
    """Adds dataset to Arena

    The pushed dataset will be visualised using exploratory data analysis plots.
    The function creates a DatasetParam object with the specified label and target name.
    When precalculation is enabled, it triggers filling the cache.

    Parameters
    -----------
    dataset : pandas.DataFrame
        Data frame to be visualised using EDA plots. This
        dataset should contain target variable.
    target : str
        Name of target column
    label : str
        Label for this dataset
    precalculate : bool or None
        Overrides constructor `precalculate` parameter when it is not None.
        If true, then only plots using this dataset will be precalculated.
    """
    if not isinstance(dataset, DataFrame):
        raise Exception('Dataset argument is not a pandas DataFrame')
    if len(dataset.columns.names) != 1:
        raise Exception('Dataset argument needs to have only one level of column names')
    precalculate = self.precalculate if precalculate is None else bool(precalculate)
    target = str(target)
    if target not in dataset.columns:
        raise Exception('Target is not a column from dataset')
    if (not isinstance(label, str)) or (len(label) == 0):
        raise Exception('Label needs to be at least one letter')
    if label in self.list_params('dataset'):
        raise Exception('Labels need to be unique')
    param = DatasetParam(dataset=dataset, label=label, target=target)
    with self.mutex:
        self.update_timestamp()
        self.datasets.append(param)
        self.variables_cache = []
    if precalculate:
        self.fill_cache({'dataset': param})
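
A usage sketch with a toy data frame (column names and the dataset label are illustrative):

import pandas as pd

df = pd.DataFrame({'age': [25, 40, 31],
                   'fare': [7.2, 71.3, 8.1],
                   'survived': [0, 1, 0]})
arena.push_dataset(df, target='survived', label='titanic_sample')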
def push_model(self, explainer, precalculate=None)

Adds model to Arena

This method encapsulates the explainer in a ModelParam object and appends it to the models field. When precalculation is enabled, it triggers filling the cache.

Parameters

explainer : Explainer
Explainer created using dalex package
precalculate : bool or None
Overrides constructor precalculate parameter when it is not None. If true, then only plots using this model will be precalculated.
Expand source code
def push_model(self, explainer, precalculate=None):
    """Adds model to Arena

    This method encapsulates the explainer in a ModelParam object and
    appends it to the models field. When precalculation is enabled,
    it triggers filling the cache.

    Parameters
    -----------
    explainer : dalex.Explainer
        Explainer created using dalex package
    precalculate : bool or None
        Overrides constructor `precalculate` parameter when it is not None.
        If true, then only plots using this model will be precalculated.
    """
    if not isinstance(explainer, Explainer):
        raise Exception('Invalid Explainer argument')
    if explainer.label in self.list_params('model'):
        raise Exception('Explainer with the same label was already added')
    precalculate = self.precalculate if precalculate is None else bool(precalculate)
    param = ModelParam(explainer)
    with self.mutex:
        self.update_timestamp()
        self.models.append(param)
        self.variables_cache = []
    if precalculate:
        self.fill_cache({'model': param})
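
A usage sketch (labels must be unique across pushed explainers; the fitted models model_rf and model_gbm and the labels are illustrative):

arena.push_model(dx.Explainer(model_rf, X, y, label='titanic_rf'))
arena.push_model(dx.Explainer(model_gbm, X, y, label='titanic_gbm'))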
def push_observations(self, observations, precalculate=None)

Adds observations to Arena

Pushed observations will be used for local explanations. The function creates an ObservationParam object for each row of the pushed data frame. The label for each observation is taken from the row name. When precalculation is enabled, it triggers filling the cache.

Parameters

observations : pandas.DataFrame
Data frame of observations to be explained using instance level plots. Label for each observation is taken from row name.
precalculate : bool or None
Overrides constructor precalculate parameter when it is not None. If true, then only plots using these observations will be precalculated.
Expand source code
def push_observations(self, observations, precalculate=None):
    """Adds observations to Arena

    Pushed observations will be used for local explanations. The function
    creates an ObservationParam object for each row of the pushed data frame.
    The label for each observation is taken from the row name. When precalculation
    is enabled, it triggers filling the cache.

    Parameters
    -----------
    observations : pandas.DataFrame
        Data frame of observations to be explained using instance level plots.
        Label for each observation is taken from row name.
    precalculate : bool or None
        Overrides constructor `precalculate` parameter when it is not None.
        If true, then only plots using these observations will be precalculated.
    """
    if not isinstance(observations, DataFrame):
        raise Exception('Observations argument is not a pandas DataFrame')
    if len(observations.index.names) != 1:
        raise Exception('Observations argument needs to have only one index')
    if not observations.index.is_unique:
        raise Exception('Observations argument needs to have unique indexes')
    precalculate = self.precalculate if precalculate is None else bool(precalculate)
    old_observations = self.list_params('observation')
    observations = observations.set_index(observations.index.astype(str))
    params_objects = []
    for x in observations.index:
        if x in old_observations:
            raise Exception('Indexes of observations need to be unique across all observations')
        params_objects.append(ObservationParam(dataset=observations, index=x))
    with self.mutex:
        self.update_timestamp()
        self.observations.extend(params_objects)
    if precalculate:
        for obs in params_objects:
            self.fill_cache({'observation': obs})
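
A usage sketch; row names become observation labels, so they must be unique strings (X and the labels are illustrative):

obs = X.iloc[:3].copy()
obs.index = ['John', 'Mary', 'Alice']
arena.push_observations(obs)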
def put_to_cache(self, plot_container)

Puts new plot to cache

Parameters

plot_container : PlotContainer
 
Expand source code
def put_to_cache(self, plot_container):
    """Puts new plot to cache

    Parameters
    -----------
    plot_container : PlotContainer
    """
    if not isinstance(plot_container, PlotContainer):
        raise Exception('Invalid plot container')
    with self.mutex:
        self.cache.append(plot_container)
def run_server(self, host='127.0.0.1', port=8181, append_data=False, arena_url='https://arena.drwhy.ai/', disable_logs=True)

Starts server for live mode of Arena

Parameters

host : str
IP address or hostname for the server
port : int
port number for the server
append_data : bool
whether the generated link should append data to an already open Arena window
arena_url : str
URL of the Arena dashboard
disable_logs : bool
whether logs should be muted

Notes

Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

Returns

None. The link to Arena is printed to standard output.
 
Expand source code
def run_server(self,
               host='127.0.0.1',
               port=8181,
               append_data=False,
               arena_url='https://arena.drwhy.ai/',
               disable_logs=True):
    """Starts server for live mode of Arena

    Parameters
    -----------
    host : str
        IP address or hostname for the server
    port : int
        port number for the server
    append_data : bool
        whether the generated link should append data to an already open Arena window
    arena_url : str
        URL of the Arena dashboard
    disable_logs : bool
        whether logs should be muted

    Notes
    --------
    Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

    Returns
    -----------
    None. The link to Arena is printed to standard output.
    """
    if self.server_thread:
        raise Exception('Server is already running. To stop it, use arena.stop_server().')
    global_check_import('flask')
    global_check_import('flask_cors')
    global_check_import('requests')
    self.server_thread = threading.Thread(target=start_server, args=(self, host, port, disable_logs))
    self.server_thread.start()
    if append_data:
        print(arena_url + '?append=http://' + host + ':' + str(port) + '/')
    else:
        print(arena_url + '?data=http://' + host + ':' + str(port) + '/')
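
For example (host and port are the defaults shown above):

arena.run_server()                    # serves at http://127.0.0.1:8181/
# or attach to an Arena window that is already open in the browser:
# arena.run_server(port=9294, append_data=True)
arena.stop_server()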
def save(self, filename='datasource.json')

Generates all plots and saves them to a JSON file

The function generates only plots that are not already cached.

Parameters

filename : str
Path or filename to output file

Notes

Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

Returns

None
 
Expand source code
def save(self, filename="datasource.json"):
    """Generate all plots and saves them to JSON file

    Function generates only not cached plots.

    Parameters
    -----------
    filename : str
        Path or filename to output file

    Notes
    --------
    Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

    Returns
    --------
    None
    """
    with open(filename, 'w') as file:
        file.write(get_json(self))
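
For example (illustrative file name; the saved file can then be used as a static data source in the Arena dashboard):

arena.save('my_datasource.json')
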
def set_option(self, plot_type, option, value)

Sets value for the plot option

Parameters

plot_type : str or None
When None, the value is set for every plot type that has an option named option. Otherwise, it is set only for plots of the specified type.
option : str
Name of the option
value : *
Value to be set

Notes

List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level

Expand source code
def set_option(self, plot_type, option, value):
    """Sets value for the plot option

    Parameters
    -----------
    plot_type : str or None
        When None, the value is set for every plot type that has
        an option named `option`. Otherwise, it is set only for
        plots of the specified type.
    option : str
        Name of the option
    value : *
        Value to be set

    Notes
    --------
    List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level
    """
    if plot_type is None:
        for plot in self.plots:
            self.set_option(plot.info.get('plotType'), option, value)
        return
    options = self.options.get(plot_type)
    if options is None:
        raise Exception('Invalid plot_type')
    if option not in options:
        return
    with self.mutex:
        self.options.get(plot_type)[option] = value
        self.clear_cache(plot_type)
    if self.precalculate:
        self.fill_cache()
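
A hedged sketch (the plot type and option names below are made up; see the linked guide for the actual ones):

arena.set_option('SomePlotType', 'grid_points', 201)  # only this plot type
arena.set_option(None, 'grid_points', 201)            # every plot type having this option
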
def stop_server(self)

Stops running server

Expand source code
def stop_server(self):
    """Stops running server"""
    if not self.server_thread:
        raise Exception('Server is not running')
    self._stop_server()
    self.server_thread.join()
    self.server_thread = None
def update_timestamp(self)

Updates timestamp

Notes

This method must be called while holding the mutex

Expand source code
def update_timestamp(self):
    """Updates timestamp

    Notes
    -------
    This method must be called while holding the mutex
    """
    now = datetime.now()
    self.timestamp = datetime.timestamp(now)
def upload(self, token=None, arena_url='https://arena.drwhy.ai/', open_browser=True)

Generates all plots and uploads them to GitHub Gist

Only plots that are not already cached are generated. If a token is not provided, the function uses OAuth and opens the GitHub authorization page.

Parameters

token : str or None
GitHub personal access token. If token is None, then OAuth is used.
arena_url : str
Address of Arena dashboard instance
open_browser : bool
Whether to open Arena after upload.

Notes

Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

Returns

Link to the Arena
 
Expand source code
def upload(self, token=None, arena_url='https://arena.drwhy.ai/', open_browser=True):
    """Generate all plots and uploads them to GitHub Gist

    Function generates only not cached plots. If token is not provided
    then function uses OAuth to open GitHub authorization page.

    Parameters
    -----------
    token : str or None
        GitHub personal access token. If token is None, then OAuth is used.
    arena_url : str
        Address of Arena dashboard instance
    open_browser : bool
        Whether to open Arena after upload.

    Notes
    --------
    Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

    Returns
    --------
    Link to the Arena
    """
    global_check_import('requests')
    if token is None:
        global_check_import('flask')
        global_check_import('flask_cors')
        token = generate_token()
    data_url = upload_arena(self, token)
    url = arena_url + '?data=' + data_url
    if open_browser:
        webbrowser.open(url)
    return url
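
For instance (illustrative; replace the placeholder with a real GitHub personal access token, or pass token=None to go through OAuth):

url = arena.upload(token='<personal-access-token>', open_browser=False)
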
class Explainer (model, data=None, y=None, predict_function=None, residual_function=None, weights=None, label=None, model_class=None, verbose=True, precalculate=True, model_type=None, model_info=None)

Create Model Explainer

Black-box models may have very different structures. This class creates a unified representation of a model, which can be further processed by various explanations. Methods of this class produce explanation objects that contain the main result attribute and can be visualised using the plot method.

The model is the only required parameter, but most of the explanations require that other parameters are provided (See data, y, predict_function, model_type).

Parameters

model : object
Model to be explained.
data : pd.DataFrame or np.ndarray (2d)
Data which will be used to calculate the explanations. It shouldn't contain the target column (See y). NOTE: If target variable is present in the data, some of the functionalities may not work properly.
y : pd.Series or pd.DataFrame or np.ndarray (1d)
Target variable with outputs / scores. It shall have the same length as data.
predict_function : function, optional
Function that takes two parameters (model, data) and returns a np.ndarray (1d) with model predictions (default is predict method extracted from the model). NOTE: This function needs to work with data as pd.DataFrame.
residual_function : function, optional
Function that takes three parameters (model, data, y) and returns a np.ndarray (1d) with model residuals (default is a function constructed from predict_function).
weights : pd.Series or np.ndarray (1d), optional
Sampling weights for observations in data. It shall have the same length as data (default is None).
label : str, optional
Model name to appear in result and plots (default is last element of the class attribute extracted from the model).
model_class : str, optional
Class of the model that is used e.g. to choose the predict_function (default is the class attribute extracted from the model). NOTE: Use if your model is wrapped with Pipeline.
verbose : bool
Print diagnostic messages during the Explainer initialization (default is True).
precalculate : bool
Calculate y_hat (predicted values) and residuals during the Explainer initialization (default is True).
model_type : {'regression', 'classification', None}
Model task type that is used e.g. in model_performance() and model_parts() (default is try to extract the information from the model, else None).
model_info : dict, optional
Dict {'model_package', 'model_package_version', ...} containing additional information to be stored.

Attributes

model : object
A model to be explained.
data : pd.DataFrame
Data which will be used to calculate the explanations.
y : np.ndarray (1d)
Target variable with outputs / scores.
predict_function : function
Function that takes two arguments (model, data) and returns np.ndarray (1d) with model predictions.
y_hat : np.ndarray (1d)
Model predictions for data.
residual_function : function
Function that takes three arguments (model, data, y) and returns np.ndarray (1d) with model residuals.
residuals : np.ndarray (1d)
Model residuals for data.
weights : np.ndarray (1d)
Sampling weights for observations in data.
label : str
Name to appear in result and plots.
model_class : str
Class of the model.
model_type : {'regression', 'classification', None}
Model task type.
model_info : dict
Dict {'model_package', 'model_package_version', ...} containing additional information.

Notes

- https://pbiecek.github.io/ema/dataSetsIntro.html#ExplainersTitanicPythonCode

Expand source code
class Explainer:
    """ Create Model Explainer

    Black-box models may have very different structures. This class creates a unified
    representation of a model, which can be further processed by various explanations.
    Methods of this class produce explanation objects that contain the main result
    attribute and can be visualised using the plot method.

    The `model` is the only required parameter, but most of the explanations require
    that other parameters are provided (See `data`, `y`, `predict_function`, `model_type`).

    Parameters
    ----------
    model : object
        Model to be explained.
    data : pd.DataFrame or np.ndarray (2d)
        Data which will be used to calculate the explanations. It shouldn't contain
        the target column (See `y`).
        NOTE: If target variable is present in the data, some of the functionalities may
        not work properly.
    y : pd.Series or pd.DataFrame or np.ndarray (1d)
        Target variable with outputs / scores. It shall have the same length as `data`.
    predict_function : function, optional
        Function that takes two parameters (model, data) and returns a np.ndarray (1d)
        with model predictions (default is predict method extracted from the model).
        NOTE: This function needs to work with `data` as pd.DataFrame.
    residual_function : function, optional
        Function that takes three parameters (model, data, y) and returns a np.ndarray (1d)
        with model residuals (default is a function constructed from `predict_function`).
    weights : pd.Series or np.ndarray (1d), optional
        Sampling weights for observations in `data`. It shall have the same length as
        `data` (default is `None`).
    label : str, optional
        Model name to appear in result and plots
        (default is last element of the class attribute extracted from the model).
    model_class : str, optional
        Class of the model that is used e.g. to choose the `predict_function`
        (default is the class attribute extracted from the model).
        NOTE: Use if your model is wrapped with Pipeline.
    verbose : bool
        Print diagnostic messages during the Explainer initialization (default is `True`).
    precalculate : bool
        Calculate y_hat (predicted values) and residuals during the Explainer
        initialization (default is `True`).
    model_type : {'regression', 'classification', None}
        Model task type that is used e.g. in `model_performance()` and `model_parts()`
        (default is try to extract the information from the model, else `None`).
    model_info: dict, optional
        Dict `{'model_package', 'model_package_version', ...}` containing additional
        information to be stored.

    Attributes
    --------
    model : object
        A model to be explained.
    data : pd.DataFrame
        Data which will be used to calculate the explanations.
    y : np.ndarray (1d)
        Target variable with outputs / scores.
    predict_function : function
        Function that takes two arguments (model, data) and returns np.ndarray (1d)
        with model predictions.
    y_hat : np.ndarray (1d)
        Model predictions for `data`.
    residual_function : function
        Function that takes three arguments (model, data, y) and returns np.ndarray (1d)
        with model residuals.
    residuals : np.ndarray (1d)
        Model residuals for `data`.
    weights : np.ndarray (1d)
        Sampling weights for observations in `data`.
    label : str
        Name to appear in result and plots.
    model_class : str
        Class of the model.
    model_type : {'regression', 'classification', `None`}
        Model task type.
    model_info: dict
        Dict `{'model_package', 'model_package_version', ...}` containing additional
        information.

    Notes
    --------
    - https://pbiecek.github.io/ema/dataSetsIntro.html#ExplainersTitanicPythonCode

    """

    def __init__(self,
                 model,
                 data=None,
                 y=None,
                 predict_function=None,
                 residual_function=None,
                 weights=None,
                 label=None,
                 model_class=None,
                 verbose=True,
                 precalculate=True,
                 model_type=None,
                 model_info=None):

        # TODO: colorize

        helper.verbose_cat("Preparation of a new explainer is initiated\n", verbose=verbose)

        # REPORT: checks for data
        data, model = checks.check_data(data, model, verbose)

        # REPORT: checks for y
        y = checks.check_y(y, data, verbose)

        # REPORT: checks for weights
        weights = checks.check_weights(weights, data, verbose)

        # REPORT: checks for model_class
        model_class, _model_info = checks.check_model_class(model_class, model, verbose)

        # REPORT: checks for label
        label, _model_info = checks.check_label(label, model_class, _model_info, verbose)

        # REPORT: checks for predict_function and model_type
        # these two are together only because of `yhat_exception_dict`
        predict_function, model_type, y_hat, _model_info = \
            checks.check_predict_function_and_model_type(predict_function, model_type,
                                                         model, data, model_class, _model_info,
                                                         precalculate, verbose)

        # if data is specified then we may test predict_function
        # at this moment we have predict function

        # REPORT: checks for residual_function
        residual_function, residuals, _model_info = checks.check_residual_function(
            residual_function, predict_function, model, data, y, _model_info, precalculate, verbose
        )

        # REPORT: checks for model_info
        _model_info = checks.check_model_info(model_info, _model_info, verbose)

        # READY to create an explainer
        self.model = model
        self.data = data
        self.y = y
        self.predict_function = predict_function
        self.y_hat = y_hat
        self.residual_function = residual_function
        self.residuals = residuals
        self.model_class = model_class
        self.label = label
        self.model_info = _model_info
        self.weights = weights
        self.model_type = model_type

        helper.verbose_cat("\nA new explainer has been created!", verbose=verbose)

    def predict(self, data):
        """Make a prediction

        This function uses the `predict_function` attribute.

        Parameters
        ----------
        data : pd.DataFrame, np.ndarray (2d)
            Data which will be used to make a prediction.

        Returns
        ----------
        np.ndarray (1d)
            Model predictions for given `data`.
        """

        checks.check_method_data(data)

        return self.predict_function(self.model, data)

    def residual(self, data, y):
        """Calculate residuals

        This function uses the `residual_function` attribute.

        Parameters
        -----------
        data : pd.DataFrame
            Data which will be used to calculate residuals.
        y : pd.Series or np.ndarray (1d)
            Target variable which will be used to calculate residuals.

        Returns
        -----------
        np.ndarray (1d)
            Model residuals for given `data` and `y`.
        """

        checks.check_method_data(data)

        return self.residual_function(self.model, data, y)

    def predict_parts(self,
                      new_observation,
                      type=('break_down_interactions', 'break_down', 'shap', 'shap_wrapper'),
                      order=None,
                      interaction_preference=1,
                      path="average",
                      B=25,
                      keep_distributions=False,
                      label=None,
                      processes=1,
                      random_state=None,
                      **kwargs):
        """Calculate predict-level variable attributions as Break Down, Shapley Values or Shap Values

        Parameters
        -----------
        new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
            An observation for which a prediction needs to be explained.
        type : {'break_down_interactions', 'break_down', 'shap', 'shap_wrapper'}
            Type of variable attributions (default is `'break_down_interactions'`).
        order : list of int or str, optional
            Parameter specific for `break_down_interactions` and `break_down`. Use a fixed
            order of variables for attribution calculation. Use integer values  or string
            variable names (default is `None`, which means order by importance).
        interaction_preference : int, optional
            Parameter specific for `break_down_interactions` type. Specify which interactions
            will be present in an explanation. The larger the integer, the more frequently
            interactions will be presented (default is `1`).
        path : list of int, optional
            Parameter specific for `shap`. If specified, then attributions for this path
            will be plotted (default is `'average'`, which plots attribution means for
            `B` random paths).
        B : int, optional
            Parameter specific for `shap`. Number of random paths to calculate
            variable attributions (default is `25`).
        keep_distributions :  bool, optional
            Save the distribution of partial predictions (default is `False`).
        label : str, optional
            Name to appear in result and plots. Overrides default.
        processes : int, optional
            Parameter specific for `shap`. Number of parallel processes to use in calculations.
            Iterated over `B` (default is `1`, which means no parallel computation).
        random_state : int, optional
            Set seed for random number generator (default is random seed).
        kwargs : dict
            Used only for `'shap_wrapper'`. Pass `shap_explainer_type` to specify, which
            Explainer shall be used: `{'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
            'LinearExplainer', 'KernelExplainer'}` (default is `None`, which automatically
            chooses an Explainer to use).
            Also keyword arguments passed to one of the: `shap.TreeExplainer.shap_values,
            shap.DeepExplainer.shap_values, shap.GradientExplainer.shap_values,
            shap.LinearExplainer.shap_values, shap.KernelExplainer.shap_values`.
            See https://github.com/slundberg/shap

        Returns
        -----------
        BreakDown, Shap or ShapWrapper class object
            Explanation object containing the main result attribute and the plot method.
            Object class, its attributes, and the plot method depend on the `type` parameter.

        Notes
        --------
        - https://pbiecek.github.io/ema/breakDown.html
        - https://pbiecek.github.io/ema/iBreakDown.html
        - https://pbiecek.github.io/ema/shapley.html
        - https://github.com/slundberg/shap
        """

        checks.check_data_again(self.data)

        types = ('break_down_interactions', 'break_down', 'shap', 'shap_wrapper')
        _type = checks.check_method_type(type, types)

        if _type == 'break_down_interactions' or _type == 'break_down':
            _predict_parts = BreakDown(
                type=_type,
                keep_distributions=keep_distributions,
                order=order,
                interaction_preference=interaction_preference
            )
        elif _type == 'shap':
            _predict_parts = Shap(
                keep_distributions=keep_distributions,
                path=path,
                B=B,
                processes=processes,
                random_state=random_state
            )
        elif _type == 'shap_wrapper':
            _global_checks.global_check_import('shap', 'SHAP explanations')
            _predict_parts = ShapWrapper('predict_parts')
        else:
            raise TypeError("Wrong type parameter.")

        _predict_parts.fit(self, new_observation, **kwargs)

        if label:
            _predict_parts.result['label'] = label

        return _predict_parts

    def predict_profile(self,
                        new_observation,
                        type=('ceteris_paribus',),
                        y=None,
                        variables=None,
                        grid_points=101,
                        variable_splits=None,
                        variable_splits_type='uniform',
                        variable_splits_with_obs=True,
                        processes=1,
                        label=None,
                        verbose=True):
        """Calculate predict-level variable profiles as Ceteris Paribus

        Parameters
        -----------
        new_observation : pd.DataFrame or np.ndarray or pd.Series
            Observations for which predictions need to be explained.
        type : {'ceteris_paribus', TODO: 'oscilations'}
            Type of variable profiles (default is `'ceteris_paribus'`).
        y : pd.Series or np.ndarray (1d), optional
            Target variable with the same length as `new_observation`.
        variables : str or array_like of str, optional
            Variables for which the profiles will be calculated
            (default is `None`, which means all of the variables).
        grid_points : int, optional
            Maximum number of points for profile calculations (default is `101`).
            NOTE: The final number of points may be lower than `grid_points`,
            e.g. if there are not enough unique values for a given variable.
        variable_splits : dict of lists, optional
            Split points for variables e.g. `{'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']}`
            (default is `None`, which means that they will be calculated using one of
            `variable_splits_type` and the `data` attribute).
        variable_splits_type : {'uniform', 'quantiles'}, optional
            Way of calculating `variable_splits`. Set `'quantiles'` for percentiles.
            (default is `'uniform'`, which means uniform grid of points).
        variable_splits_with_obs: bool, optional
            Add variable values of `new_observation` data to the `variable_splits`
            (default is `True`).
        label : str, optional
            Name to appear in result and plots. Overrides default.
        processes : int, optional
            Number of parallel processes to use in calculations. Iterated over `variables`
            (default is `1`, which means no parallel computation).
        verbose : bool, optional
            Print tqdm progress bar (default is `True`).

        Returns
        -----------
        CeterisParibus class object
            Explanation object containing the main result attribute and the plot method.

        Notes
        --------
        - https://pbiecek.github.io/ema/ceterisParibus.html
        """

        checks.check_data_again(self.data)

        types = ('ceteris_paribus',)
        _type = checks.check_method_type(type, types)

        if _type == 'ceteris_paribus':
            _predict_profile = CeterisParibus(
                variables=variables,
                grid_points=grid_points,
                variable_splits=variable_splits,
                variable_splits_type=variable_splits_type,
                variable_splits_with_obs=variable_splits_with_obs,
                processes=processes
            )
        else:
            raise TypeError("Wrong type parameter.")

        _predict_profile.fit(self, new_observation, y, verbose)

        if label:
            _predict_profile.result['_label_'] = label

        return _predict_profile

    def predict_surrogate(self,
                          new_observation,
                          type='lime',
                          **kwargs):
        """Wrapper for surrogate model explanations

        This function uses the lime package to create the model explanation.
        See https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

        Parameters
        -----------
        new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
            An observation for which a prediction needs to be explained.
        type : {'lime'}
            Type of explanation method
            (default is `'lime'`, which uses the lime package to create an explanation).
        kwargs : dict
            Keyword arguments passed to the lime.lime_tabular.LimeTabularExplainer object
            and the LimeTabularExplainer.explain_instance method. Exceptions are:
            `training_data`, `mode`, `data_row` and `predict_fn`. Other parameters:
            https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

        Returns
        -----------
        lime.explanation.Explanation
            Explanation object.

        Notes
        -----------
        - https://github.com/marcotcr/lime
        """

        checks.check_data_again(self.data)

        if type == 'lime':
            _global_checks.global_check_import('lime', 'LIME explanations')
            _new_observation = checks.check_new_observation_lime(new_observation)
            _explanation = utils.create_lime_explanation(self, _new_observation, **kwargs)
        else:
            raise TypeError("Wrong 'type' parameter.")

        return _explanation

    def model_performance(self,
                          model_type=None,
                          cutoff=0.5,
                          label=None):
        """Calculate model-level model performance measures

        Parameters
        -----------
        model_type : {'regression', 'classification', None}
            Model task type that is used to choose the proper performance measures
            (default is `None`, which means try to extract from the `model_type` attribute).
        cutoff : float, optional
            Cutoff for predictions in classification models. Needed for measures like
            recall, precision, acc, f1 (default is `0.5`).
        label : str, optional
            Name to appear in result and plots. Overrides default.

        Returns
        -----------
        ModelPerformance class object
            Explanation object containing the main result attribute and the plot method.

        Notes
        --------
        - https://pbiecek.github.io/ema/modelPerformance.html
        """

        checks.check_data_again(self.data)
        checks.check_y_again(self.y)

        if model_type is None and self.model_type is None:
            raise TypeError("if self.model_type is None, then model_type must be not None")
        elif model_type is None:
            model_type = self.model_type

        _model_performance = ModelPerformance(
            model_type=model_type,
            cutoff=cutoff
        )
        _model_performance.fit(self)

        if label:
            _model_performance.result['label'] = label

        return _model_performance

    def model_parts(self,
                    loss_function=None,
                    type=('variable_importance', 'ratio', 'difference', 'shap_wrapper'),
                    N=1000,
                    B=10,
                    variables=None,
                    variable_groups=None,
                    keep_raw_permutations=True,
                    label=None,
                    processes=1,
                    random_state=None,
                    **kwargs):

        """Calculate model-level variable importance

        Parameters
        -----------
        loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
            If string, then such loss function will be used to assess variable importance
            (default is `'rmse'` or `'1-auc'`, depends on `model_type` attribute).
        type : {'variable_importance', 'ratio', 'difference', 'shap_wrapper'}
            Type of transformation that will be applied to dropout loss.
            (default is `'variable_importance'`, which is Permutational Variable Importance).
        N : int, optional
            Number of observations that will be sampled from the `data` attribute before
            the calculation of variable importance. `None` means all `data` (default is `1000`).
        B : int, optional
            Number of permutation rounds to perform on each variable (default is `10`).
        variables : array_like of str, optional
            Variables for which the importance will be calculated
            (default is `None`, which means all of the variables).
            NOTE: Ignored if `variable_groups` is not `None`.
        variable_groups : dict of lists, optional
            Group the variables to calculate their joint variable importance
            e.g. `{'X': ['x1', 'x2'], 'Y': ['y1', 'y2']}` (default is `None`).
        keep_raw_permutations: bool, optional
            Save results for all permutation rounds (default is `True`).
        label : str, optional
            Name to appear in result and plots. Overrides default.
        processes : int, optional
            Number of parallel processes to use in calculations. Iterated over `B`
            (default is `1`, which means no parallel computation).
        random_state : int, optional
            Set seed for random number generator (default is random seed).
        kwargs : dict
            Used only for 'shap_wrapper'. Pass `shap_explainer_type` to specify, which
            Explainer shall be used: `{'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
            'LinearExplainer', 'KernelExplainer'}`.
            Also keyword arguments passed to one of the: `shap.TreeExplainer.shap_values,
            shap.DeepExplainer.shap_values, shap.GradientExplainer.shap_values,
            shap.LinearExplainer.shap_values, shap.KernelExplainer.shap_values`.
            See https://github.com/slundberg/shap

        Returns
        -----------
        VariableImportance or ShapWrapper class object
            Explanation object containing the main result attribute and the plot method.
            Object class, its attributes, and the plot method depend on the `type` parameter.

        Notes
        --------
        - https://pbiecek.github.io/ema/featureImportance.html
        - https://github.com/slundberg/shap
        """

        checks.check_data_again(self.data)

        types = ('variable_importance', 'ratio', 'difference', 'shap_wrapper')
        aliases = {'permutational': 'variable_importance', 'feature_importance': 'variable_importance'}
        _type = checks.check_method_type(type, types, aliases)

        loss_function = checks.check_method_loss_function(self, loss_function)

        if _type != 'shap_wrapper':
            checks.check_y_again(self.y)

            _model_parts = VariableImportance(
                loss_function=loss_function,
                type=_type,
                N=N,
                B=B,
                variables=variables,
                variable_groups=variable_groups,
                processes=processes,
                random_state=random_state,
                keep_raw_permutations=keep_raw_permutations,
            )
            _model_parts.fit(self)

            if label:
                _model_parts.result['label'] = label

        elif _type == 'shap_wrapper':
            _global_checks.global_check_import('shap', 'SHAP explanations')
            _model_parts = ShapWrapper('model_parts')
            if N is None:
                N = self.data.shape[0]
            else:
                N = min(N, self.data.shape[0])

            # sample N rows from all of data (not just the first N)
            sampled_rows = np.random.choice(self.data.shape[0], N, replace=False)
            sampled_data = self.data.iloc[sampled_rows, :]

            _model_parts.fit(self, sampled_data, **kwargs)
        else:
            raise TypeError("Wrong type parameter");

        return _model_parts

    def model_profile(self,
                      type=('partial', 'accumulated', 'conditional'),
                      N=300,
                      variables=None,
                      variable_type='numerical',
                      groups=None,
                      span=0.25,
                      grid_points=101,
                      variable_splits=None,
                      variable_splits_type='uniform',
                      center=True,
                      label=None,
                      processes=1,
                      random_state=None,
                      verbose=True):

        """Calculate model-level variable profiles as Partial or Accumulated Dependence

        Parameters
        -----------
        type : {'partial', 'accumulated', 'conditional'}
            Type of model profiles
            (default is `'partial'` for Partial Dependence Profiles).
        N : int, optional
            Number of observations that will be sampled from the `data` attribute before
            the calculation of variable profiles. `None` means all `data` (default is `300`).
        variables : str or array_like of str, optional
            Variables for which the profiles will be calculated
            (default is `None`, which means all of the variables).
        variable_type : {'numerical', 'categorical'}
            Calculate the profiles for numerical or categorical variables
            (default is `'numerical'`).
        groups : str or array_like of str, optional
            Names of categorical variables that will be used for profile grouping
            (default is `None`, which means no grouping).
        span : float, optional
            Smoothing coefficient used as sd for gaussian kernel (default is `0.25`).
        grid_points : int, optional
            Maximum number of points for profile calculations (default is `101`).
            NOTE: The final number of points may be lower than `grid_points`,
            e.g. if there are not enough unique values for a given variable.
        variable_splits : dict of lists, optional
            Split points for variables e.g. `{'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']}`
            (default is `None`, which means that they will be distributed uniformly).
        variable_splits_type : {'uniform', 'quantiles'}, optional
            Way of calculating `variable_splits`. Set 'quantiles' for percentiles.
            (default is `'uniform'`, which means uniform grid of points).
        center : bool, optional
            Theoretically Accumulated Profiles start at `0`, but are centered to compare
            them with Partial Dependence Profiles (default is `True`, which means center
            around the average `y_hat` calculated on the data sample).
        label : str, optional
            Name to appear in result and plots. Overrides default.
        processes : int, optional
            Number of parallel processes to use in calculations. Iterated over `variables`
            (default is `1`, which means no parallel computation).
        random_state : int, optional
            Set seed for random number generator (default is random seed).
        verbose : bool, optional
            Print tqdm progress bar (default is `True`).

        Returns
        -----------
        AggregatedProfiles class object
            Explanation object containing the main result attribute and the plot method.

        Notes
        --------
        - https://pbiecek.github.io/ema/partialDependenceProfiles.html
        - https://pbiecek.github.io/ema/accumulatedLocalProfiles.html
        """

        checks.check_data_again(self.data)

        types = ('partial', 'accumulated', 'conditional')
        aliases = {'pdp': 'partial', 'ale': 'accumulated'}
        _type = checks.check_method_type(type, types, aliases)

        if N is None:
            N = self.data.shape[0]
        else:
            N = min(N, self.data.shape[0])

        if random_state is not None:
            np.random.seed(random_state)

        # sample N rows from all of data (not just the first N)
        I = np.random.choice(self.data.shape[0], N, replace=False)

        _ceteris_paribus = CeterisParibus(grid_points=grid_points,
                                          variables=variables,
                                          variable_splits=variable_splits,
                                          variable_splits_type=variable_splits_type,
                                          processes=processes)
        _y = self.y[I] if self.y is not None else self.y
        _ceteris_paribus.fit(self, self.data.iloc[I, :], y=_y, verbose=verbose)

        _model_profile = AggregatedProfiles(
            type=_type,
            variables=variables,
            variable_type=variable_type,
            groups=groups,
            span=span,
            center=center,
            random_state=random_state
        )

        _model_profile.fit(_ceteris_paribus, verbose)

        if label:
            _model_profile.result['_label_'] = label

        return _model_profile

    def model_diagnostics(self,
                          variables=None,
                          label=None):
        """Calculate model-level residuals diagnostics

        Parameters
        -----------
        variables : str or array_like of str, optional
            Variables for which the data will be calculated
            (default is `None`, which means all of the variables).
        label : str, optional
            Name to appear in result and plots. Overrides default.

        Returns
        -----------
        ResidualDiagnostics class object
            Explanation object containing the main result attribute and the plot method.

        Notes
        --------
        - https://pbiecek.github.io/ema/residualDiagnostic.html
        """

        checks.check_data_again(self.data)
        checks.check_y_again(self.y)

        _residual_diagnostics = ResidualDiagnostics(
            variables=variables
        )
        _residual_diagnostics.fit(self)

        if label:
            _residual_diagnostics.result['label'] = label

        return _residual_diagnostics

    def model_surrogate(self,
                        type=('tree', 'linear'),
                        max_vars=5,
                        max_depth=3,
                        **kwargs):
        """Create a surrogate interpretable model from the black-box model

        This method uses the scikit-learn package to create a surrogate
        interpretable model (e.g. decision tree) from the black-box model.
        It aims to use the most important features and add a plot method to
        the model, so that it can be easily interpreted. See Notes section
        for references.

        Parameters
        -----------
        type : {'tree', 'linear'}
            Type of a surrogate model. This can be a decision tree or a linear model
            (default is `'tree'`).
        max_vars : int, optional
            Maximum number of variables that will be used in surrogate model training.
            These are the most important variables to the black-box model (default is `5`).
        max_depth : int, optional
            The maximum depth of the tree. If `None`, then nodes are expanded until all
            leaves are pure or until all leaves contain less than min_samples_split
            samples (default is `3` for interpretable plot).
        kwargs : dict
            Keyword arguments passed to one of the: `sklearn.tree.DecisionTreeClassifier,
            sklearn.tree.DecisionTreeRegressor, sklearn.linear_model.LogisticRegression,
            sklearn.linear_model.LinearRegression`


        Returns
        -----------
        One of: sklearn.tree.DecisionTreeClassifier, sklearn.tree.DecisionTreeRegressor, sklearn.linear_model.LogisticRegression, sklearn.linear_model.LinearRegression
        A surrogate model with additional:
            - `plot` method
            - `performance` attribute
            - `feature_names` attribute
            - `class_names` attribute

        Notes
        -----------
        - https://christophm.github.io/interpretable-ml-book/global.html
        - https://github.com/scikit-learn/scikit-learn
        """

        _global_checks.global_check_import('scikit-learn', 'surrogate models')
        checks.check_data_again(self.data)

        types = ('tree', 'linear')
        _type = checks.check_method_type(type, types)

        surrogate_model = utils.create_surrogate_model(explainer=self,
                                                       type=_type,
                                                       max_vars=max_vars,
                                                       max_depth=max_depth,
                                                       **kwargs)

        return surrogate_model

    def model_fairness(self,
                       protected,
                       privileged,
                       cutoff=0.5,
                       label=None,
                       **kwargs):
        """Creates a model-level fairness explanation that enables bias detection

        This method returns a GroupFairnessClassification object that for now
        supports only classification models. GroupFairnessClassification object
        works as a wrapper of the protected attribute and the Explainer from which
        `y` and `y_hat` attributes were extracted. Along with an information about
        privileged subgroup (value in the `protected` parameter), those 3 vectors
        create triplet `(y, y_hat, protected)` which is a base for all further
        fairness calculations and visualizations.

        Parameters
        -----------
        protected : np.ndarray (1d)
        Vector, preferably a 1-dimensional np.ndarray containing strings,
        which denotes membership in a subgroup. It doesn't have to be binary.
            It doesn't need to be in data. It is sometimes suggested not to use
            sensitive attributes in modelling, but still check model bias for them.
            NOTE: List and pd.Series are also supported, however if provided
            they will be transformed into a (1d) np.ndarray with dtype 'U'.
        privileged : str
            Subgroup that is suspected to have the most privilege.
            It needs to be a string present in `protected`.
        cutoff : float or dict, optional
            Threshold for probabilistic output of a classifier.
            It might be: a `float` - same for all subgroups from `protected`,
            or a `dict` - individually adjusted for each subgroup
            (must have values from `protected` as keys).
        label : str, optional
            Name to appear in result and plots. Overrides default.
        kwargs : dict
            Keyword arguments. It supports `verbose`, which is a boolean
            value telling if additional output should be printed
            (`True`) or not (`False`, default).

        Returns
        -----------
        GroupFairnessClassification class object (a subclass of _FairnessObject)
            Explanation object containing the main result attribute and the plot method.

        It has the following main attributes:
            - result : `pd.DataFrame`
                Scaled `metric_scores`. The scaling is performed by
                dividing all metric scores by scores of the privileged subgroup.
            - metric_scores : `pd.DataFrame`
                Raw metric scores for each subgroup.
            - parity_loss : `pd.Series`
                It is a summarised `result`. From each metric (column) a logarithm
                is calculated, then the absolute value is taken and summarised.
                Therefore, for metric M:
                    `parity_loss` is a `sum(abs(log(M_i / M_privileged)))`
                        where `M_i` is the metric score for subgroup `i`.
            - label : `str`
                `label` attribute from the Explainer object.
                    Labels must be unique when plotting.
            - cutoff : `dict`
                A float value for each subgroup (key in dict).

        Notes
        -----------
        - Verma, S. & Rubin, J. (2018) https://fairware.cs.umass.edu/papers/Verma.pdf
        - Zafar, M.B., et al. (2017) https://arxiv.org/pdf/1610.08452.pdf
        - Hardt, M., et al. (2016) https://arxiv.org/pdf/1610.02413.pdf
        """

        if self.model_type != 'classification':
            raise ValueError(
                "fairness module for now supports only classification models. "
                "Explainer attribute 'model_type' must be 'classification'")

        fobject = GroupFairnessClassification(y=self.y,
                                              y_hat=self.y_hat,
                                              protected=protected,
                                              privileged=privileged,
                                              cutoff=cutoff,
                                              label=self.label,
                                              **kwargs)

        if label:
            fobject.label = label

        return fobject

    def dumps(self, *args, **kwargs):
        """Return the pickled representation (bytes object) of the Explainer

        This method uses the pickle package. See
        https://docs.python.org/3/library/pickle.html#pickle.dumps

        NOTE: local functions and lambdas cannot be pickled.
        Attribute `residual_function` by default contains lambda; thus,
        if not provided by the user, it will be dropped before the dump.

        Parameters
        -----------
        args : dict
            Positional arguments passed to the pickle.dumps function.
        kwargs : dict
            Keyword arguments passed to the pickle.dumps function.

        Returns
        -----------
        bytes object
        """

        from copy import deepcopy
        to_dump = deepcopy(self)
        to_dump = checks.check_if_local_and_lambda(to_dump)

        import pickle
        return pickle.dumps(to_dump, *args, **kwargs)

    def dump(self, file, *args, **kwargs):
        """Write the pickled representation of the Explainer to the file (pickle)

        This method uses the pickle package. See
        https://docs.python.org/3/library/pickle.html#pickle.dump

        NOTE: local functions and lambdas cannot be pickled.
        Attribute `residual_function` by default contains lambda; thus,
        if not provided by the user, it will be dropped before the dump.

        Parameters
        -----------
        file : ...
            A file object opened for binary writing, or an io.BytesIO instance.
        args : dict
            Positional arguments passed to the pickle.dump function.
        kwargs : dict
            Keyword arguments passed to the pickle.dump function.
        """

        from copy import deepcopy
        to_dump = deepcopy(self)
        to_dump = checks.check_if_local_and_lambda(to_dump)

        import pickle
        return pickle.dump(to_dump, file, *args, **kwargs)

    @staticmethod
    def loads(data, use_defaults=True, *args, **kwargs):
        """Load the Explainer from the pickled representation (bytes object)

        This method uses the pickle package. See
        https://docs.python.org/3/library/pickle.html#pickle.loads

        NOTE: local functions and lambdas cannot be pickled.
        If `use_defaults` is set to `True`, then dropped functions are set to defaults.

        Parameters
        -----------
        data : bytes object
            Binary representation of the Explainer.
        use_defaults : bool
            Replace empty `predict_function` and `residual_function` with default
            values like in Explainer initialization (default is `True`).
        args : dict
            Positional arguments passed to the pickle.loads function.
        kwargs : dict
            Keyword arguments passed to the pickle.loads function.

        Returns
        -----------
        Explainer object
        """

        import pickle
        exp = pickle.loads(data, *args, **kwargs)

        if use_defaults:
            exp = checks.check_if_empty_fields(exp)

        return exp

    @staticmethod
    def load(file, use_defaults=True, *args, **kwargs):
        """Read the pickled representation of the Explainer from the file (pickle)

        This method uses the pickle package. See
        https://docs.python.org/3/library/pickle.html#pickle.load

        NOTE: local functions and lambdas cannot be pickled.
        If `use_defaults` is set to `True`, then dropped functions are set to defaults.

        Parameters
        -----------
        file : ...
            A binary file object opened for reading, or an io.BytesIO object.
        use_defaults : bool
            Replace empty `predict_function` and `residual_function` with default
            values like in Explainer initialization (default is `True`).
        args : dict
            Positional arguments passed to the pickle.load function.
        kwargs : dict
            Keyword arguments passed to the pickle.load function.

        Returns
        -----------
        Explainer object
        """

        import pickle
        exp = pickle.load(file, *args, **kwargs)

        if use_defaults:
            exp = checks.check_if_empty_fields(exp)

        return exp
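
A minimal end-to-end sketch (assuming scikit-learn is installed; the numeric-only feature subset is used just to keep the example self-contained):

import dalex as dx
from sklearn.ensemble import RandomForestClassifier

data = dx.datasets.load_titanic()
X = data[['age', 'fare', 'sibsp', 'parch']]  # numeric columns only, for brevity
y = data.survived

model = RandomForestClassifier(random_state=0).fit(X, y)
exp = dx.Explainer(model, X, y, label='titanic_rf')

exp.predict(X.head())            # np.ndarray (1d) with predictions
exp.model_performance()          # model-level performance measures
exp.predict_parts(X.iloc[[0]])   # Break Down attributions for one observation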

Static methods

def load(file, use_defaults=True, *args, **kwargs)

Read the pickled representation of the Explainer from the file (pickle)

This method uses the pickle package. See https://docs.python.org/3/library/pickle.html#pickle.load

NOTE: local functions and lambdas cannot be pickled. If use_defaults is set to True, then dropped functions are set to defaults.

Parameters

file : …
A binary file object opened for reading, or an io.BytesIO object.
use_defaults : bool
Replace empty predict_function and residual_function with default values like in Explainer initialization (default is True).
args : dict
Positional arguments passed to the pickle.load function.
kwargs : dict
Keyword arguments passed to the pickle.load function.

Returns

Explainer object
 
Expand source code
@staticmethod
def load(file, use_defaults=True, *args, **kwargs):
    """Read the pickled representation of the Explainer from the file (pickle)

    This method uses the pickle package. See
    https://docs.python.org/3/library/pickle.html#pickle.load

    NOTE: local functions and lambdas cannot be pickled.
    If `use_defaults` is set to `True`, then dropped functions are set to defaults.

    Parameters
    -----------
    file : ...
        A binary file object opened for reading, or an io.BytesIO object.
    use_defaults : bool
        Replace empty `predict_function` and `residual_function` with default
        values like in Explainer initialization (default is `True`).
    args : dict
        Positional arguments passed to the pickle.load function.
    kwargs : dict
        Keyword arguments passed to the pickle.load function.

    Returns
    -----------
    Explainer object
    """

    import pickle
    exp = pickle.load(file, *args, **kwargs)

    if use_defaults:
        exp = checks.check_if_empty_fields(exp)

    return exp
def loads(data, use_defaults=True, *args, **kwargs)

Load the Explainer from the pickled representation (bytes object)

This method uses the pickle package. See https://docs.python.org/3/library/pickle.html#pickle.loads

NOTE: local functions and lambdas cannot be pickled. If use_defaults is set to True, then dropped functions are set to defaults.

Parameters

data : bytes object
Binary representation of the Explainer.
use_defaults : bool
Replace empty predict_function and residual_function with default values like in Explainer initialization (default is True).
args : dict
Positional arguments passed to the pickle.loads function.
kwargs : dict
Keyword arguments passed to the pickle.loads function.

Returns

Explainer object
 
Expand source code
@staticmethod
def loads(data, use_defaults=True, *args, **kwargs):
    """Load the Explainer from the pickled representation (bytes object)

    This method uses the pickle package. See
    https://docs.python.org/3/library/pickle.html#pickle.loads

    NOTE: local functions and lambdas cannot be pickled.
    If `use_defaults` is set to `True`, then dropped functions are set to defaults.

    Parameters
    -----------
    data : bytes object
        Binary representation of the Explainer.
    use_defaults : bool
        Replace empty `predict_function` and `residual_function` with default
        values like in Explainer initialization (default is `True`).
    args : dict
        Positional arguments passed to the pickle.loads function.
    kwargs : dict
        Keyword arguments passed to the pickle.loads function.

    Returns
    -----------
    Explainer object
    """

    import pickle
    exp = pickle.loads(data, *args, **kwargs)

    if use_defaults:
        exp = checks.check_if_empty_fields(exp)

    return exp

Methods

def dump(self, file, *args, **kwargs)

Write the pickled representation of the Explainer to the file (pickle)

This method uses the pickle package. See https://docs.python.org/3/library/pickle.html#pickle.dump

NOTE: local functions and lambdas cannot be pickled. Attribute residual_function by default contains lambda; thus, if not provided by the user, it will be dropped before the dump.

Parameters

file : …
A file object opened for binary writing, or an io.BytesIO instance.
args : dict
Positional arguments passed to the pickle.dump function.
kwargs : dict
Keyword arguments passed to the pickle.dump function.
Expand source code
def dump(self, file, *args, **kwargs):
    """Write the pickled representation of the Explainer to the file (pickle)

    This method uses the pickle package. See
    https://docs.python.org/3/library/pickle.html#pickle.dump

    NOTE: local functions and lambdas cannot be pickled.
    Attribute `residual_function` by default contains lambda; thus,
    if not provided by the user, it will be dropped before the dump.

    Parameters
    -----------
    file : ...
        A file object opened for binary writing, or an io.BytesIO instance.
    args : dict
        Positional arguments passed to the pickle.dump function.
    kwargs : dict
        Keyword arguments passed to the pickle.dump function.
    """

    from copy import deepcopy
    to_dump = deepcopy(self)
    to_dump = checks.check_if_local_and_lambda(to_dump)

    import pickle
    return pickle.dump(to_dump, file, *args, **kwargs)
def dumps(self, *args, **kwargs)

Return the pickled representation (bytes object) of the Explainer

This method uses the pickle package. See https://docs.python.org/3/library/pickle.html#pickle.dumps

NOTE: local functions and lambdas cannot be pickled. Attribute residual_function by default contains lambda; thus, if not provided by the user, it will be dropped before the dump.

Parameters

args : dict
Positional arguments passed to the pickle.dumps function.
kwargs : dict
Keyword arguments passed to the pickle.dumps function.

Returns

bytes object
 
Expand source code
def dumps(self, *args, **kwargs):
    """Return the pickled representation (bytes object) of the Explainer

    This method uses the pickle package. See
    https://docs.python.org/3/library/pickle.html#pickle.dumps

    NOTE: local functions and lambdas cannot be pickled.
    Attribute `residual_function` by default contains lambda; thus,
    if not provided by the user, it will be dropped before the dump.

    Parameters
    -----------
    args : dict
        Positional arguments passed to the pickle.dumps function.
    kwargs : dict
        Keyword arguments passed to the pickle.dumps function.

    Returns
    -----------
    bytes object
    """

    from copy import deepcopy
    to_dump = deepcopy(self)
    to_dump = checks.check_if_local_and_lambda(to_dump)

    import pickle
    return pickle.dumps(to_dump, *args, **kwargs)
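
A pickling round-trip sketch (illustrative; exp is an existing Explainer and dx is the imported dalex package):

with open('explainer.pkl', 'wb') as fd:
    exp.dump(fd)                   # or: blob = exp.dumps()

with open('explainer.pkl', 'rb') as fd:
    exp2 = dx.Explainer.load(fd)   # or: exp2 = dx.Explainer.loads(blob)
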
def model_diagnostics(self, variables=None, label=None)

Calculate model-level residuals diagnostics

Parameters

variables : str or array_like of str, optional
Variables for which the data will be calculated (default is None, which means all of the variables).
label : str, optional
Name to appear in result and plots. Overrides default.

Returns

ResidualDiagnostics class object
Explanation object containing the main result attribute and the plot method.

Notes

- https://pbiecek.github.io/ema/residualDiagnostic.html

Expand source code
def model_diagnostics(self,
                      variables=None,
                      label=None):
    """Calculate model-level residuals diagnostics

    Parameters
    -----------
    variables : str or array_like of str, optional
        Variables for which the data will be calculated
        (default is `None`, which means all of the variables).
    label : str, optional
        Name to appear in result and plots. Overrides default.

    Returns
    -----------
    ResidualDiagnostics class object
        Explanation object containing the main result attribute and the plot method.

    Notes
    --------
    - https://pbiecek.github.io/ema/residualDiagnostic.html
    """

    checks.check_data_again(self.data)
    checks.check_y_again(self.y)

    _residual_diagnostics = ResidualDiagnostics(
        variables=variables
    )
    _residual_diagnostics.fit(self)

    if label:
        _residual_diagnostics.result['label'] = label

    return _residual_diagnostics
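
A minimal usage sketch, assuming the exp object from the dump() example:

    rd = exp.model_diagnostics()
    rd.result.head()   # per-observation residuals joined with the data
    rd.plot()          # residual diagnostics plot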
def model_fairness(self, protected, privileged, cutoff=0.5, label=None, **kwargs)

Creates a model-level fairness explanation that enables bias detection

This method returns a GroupFairnessClassification object, which for now supports only classification models. It wraps the protected attribute together with the y and y_hat attributes extracted from the Explainer. Combined with information about the privileged subgroup (a value present in protected), these three vectors form the triplet (y, y_hat, protected), which is the basis for all further fairness calculations and visualizations.

Parameters

protected : np.ndarray (1d)
Vector, preferably a 1-dimensional np.ndarray of strings, which denotes membership in a subgroup. It does not have to be binary, and it does not need to be a column in data; it is sometimes suggested not to use sensitive attributes in modelling, while still checking the model for bias against them. NOTE: list and pd.Series are also supported; if provided, they will be transformed into a (1d) np.ndarray with dtype 'U'.
privileged : str
Subgroup that is suspected to have the most privilege. It needs to be a string present in protected.
cutoff : float or dict, optional
Threshold for probabilistic output of a classifier. It might be: a float - same for all subgroups from protected, or a dict - individually adjusted for each subgroup (must have values from protected as keys).
label : str, optional
Name to appear in result and plots. Overrides default.
kwargs : dict
Keyword arguments. It supports verbose, which is a boolean value telling if additional output should be printed (True) or not (False, default).

Returns

GroupFairnessClassification class object (a subclass of _FairnessObject)
Explanation object containing the main result attribute and the plot method.
It has the following main attributes:
  • result : pd.DataFrame Scaled metric_scores. The scaling is performed by dividing all metric scores by scores of the privileged subgroup.
  • metric_scores : pd.DataFrame Raw metric scores for each subgroup.
  • parity_loss : pd.Series It is a summarised result. From each metric (column) a logarithm is calculated, then the absolute value is taken and summarised. Therefore, for metric M: parity_loss is a sum(abs(log(M_i / M_privileged))) where M_i is the metric score for subgroup i.
  • label : str label attribute from the Explainer object. Labels must be unique when plotting.
  • cutoff : dict A float value for each subgroup (key in dict).

Notes

- Verma, S. & Rubin, J. (2018) https://fairware.cs.umass.edu/papers/Verma.pdf
- Zafar, M.B., et al. (2017) https://arxiv.org/pdf/1610.08452.pdf
- Hardt, M., et al. (2016) https://arxiv.org/pdf/1610.02413.pdf

Expand source code
def model_fairness(self,
                   protected,
                   privileged,
                   cutoff=0.5,
                   label=None,
                   **kwargs):
    """Creates a model-level fairness explanation that enables bias detection

    This method returns a GroupFairnessClassification object, which for now
    supports only classification models. It wraps the `protected` attribute
    together with the `y` and `y_hat` attributes extracted from the Explainer.
    Combined with information about the privileged subgroup (a value present
    in `protected`), these three vectors form the triplet `(y, y_hat, protected)`,
    which is the basis for all further fairness calculations and visualizations.

    Parameters
    -----------
    protected : np.ndarray (1d)
        Vector, preferably a 1-dimensional np.ndarray of strings, which denotes
        membership in a subgroup. It does not have to be binary, and it does not
        need to be a column in `data`; it is sometimes suggested not to use
        sensitive attributes in modelling, while still checking the model for
        bias against them. NOTE: list and pd.Series are also supported;
        if provided, they will be transformed into a (1d) np.ndarray with dtype 'U'.
    privileged : str
        Subgroup that is suspected to have the most privilege.
        It needs to be a string present in `protected`.
    cutoff : float or dict, optional
        Threshold for probabilistic output of a classifier.
        It might be: a `float` - same for all subgroups from `protected`,
        or a `dict` - individually adjusted for each subgroup
        (must have values from `protected` as keys).
    label : str, optional
        Name to appear in result and plots. Overrides default.
    kwargs : dict
        Keyword arguments. It supports `verbose`, which is a boolean
        value telling if additional output should be printed
        (`True`) or not (`False`, default).

    Returns
    -----------
    GroupFairnessClassification class object (a subclass of _FairnessObject)
        Explanation object containing the main result attribute and the plot method.

    It has the following main attributes:
        - result : `pd.DataFrame`
            Scaled `metric_scores`. The scaling is performed by
            dividing all metric scores by scores of the privileged subgroup.
        - metric_scores : `pd.DataFrame`
            Raw metric scores for each subgroup.
        - parity_loss : `pd.Series`
            It is a summarised `result`. From each metric (column) a logarithm
            is calculated, then the absolute value is taken and summarised.
            Therefore, for metric M:
                `parity_loss` is a `sum(abs(log(M_i / M_privileged)))`
                    where `M_i` is the metric score for subgroup `i`.
        - label : `str`
            `label` attribute from the Explainer object.
                Labels must be unique when plotting.
        - cutoff : `dict`
            A float value for each subgroup (key in dict).

    Notes
    -----------
    - Verma, S. & Rubin, J. (2018) https://fairware.cs.umass.edu/papers/Verma.pdf
    - Zafar, M.B., et al. (2017) https://arxiv.org/pdf/1610.08452.pdf
    - Hardt, M., et al. (2016) https://arxiv.org/pdf/1610.02413.pdf
    """

    if self.model_type != 'classification':
        raise ValueError(
            "fairness module for now supports only classification models. "
            "Explainer attribute 'model_type' must be 'classification'")

    fobject = GroupFairnessClassification(y=self.y,
                                          y_hat=self.y_hat,
                                          protected=protected,
                                          privileged=privileged,
                                          cutoff=cutoff,
                                          label=self.label,
                                          **kwargs)

    if label:
        fobject.label = label

    return fobject
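
A usage sketch, assuming the exp object from the dump() example; the gender column of the bundled titanic data plays the role of the protected vector:

    protected = data.gender.values   # array of strings, e.g. 'male'/'female'
    fobject = exp.model_fairness(protected=protected, privileged='male')
    fobject.parity_loss              # summarised per-metric unfairness
    fobject.fairness_check()         # prints a bias verdict
    fobject.plot()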
def model_parts(self, loss_function=None, type=('variable_importance', 'ratio', 'difference', 'shap_wrapper'), N=1000, B=10, variables=None, variable_groups=None, keep_raw_permutations=True, label=None, processes=1, random_state=None, **kwargs)

Calculate model-level variable importance

Parameters

loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
If a string, that loss function will be used to assess variable importance (default is 'rmse' or '1-auc', depending on the model_type attribute).
type : {'variable_importance', 'ratio', 'difference', 'shap_wrapper'}
Type of transformation that will be applied to dropout loss. (default is 'variable_importance', which is Permutational Variable Importance).
N : int, optional
Number of observations that will be sampled from the data attribute before the calculation of variable importance. None means all data (default is 1000).
B : int, optional
Number of permutation rounds to perform on each variable (default is 10).
variables : array_like of str, optional
Variables for which the importance will be calculated (default is None, which means all of the variables). NOTE: Ignored if variable_groups is not None.
variable_groups : dict of lists, optional
Group the variables to calculate their joint variable importance e.g. {'X': ['x1', 'x2'], 'Y': ['y1', 'y2']} (default is None).
keep_raw_permutations : bool, optional
Save results for all permutation rounds (default is True).
label : str, optional
Name to appear in result and plots. Overrides default.
processes : int, optional
Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).
kwargs : dict
Used only for 'shap_wrapper'. Pass shap_explainer_type to specify which Explainer shall be used: {'TreeExplainer', 'DeepExplainer', 'GradientExplainer', 'LinearExplainer', 'KernelExplainer'}. Also keyword arguments passed to one of the: shap.TreeExplainer.shap_values, shap.DeepExplainer.shap_values, shap.GradientExplainer.shap_values, shap.LinearExplainer.shap_values, shap.KernelExplainer.shap_values. See https://github.com/slundberg/shap

Returns

VariableImportance or ShapWrapper class object
Explanation object containing the main result attribute and the plot method. Object class, its attributes, and the plot method depend on the type parameter.

Notes

- https://pbiecek.github.io/ema/featureImportance.html
- https://github.com/slundberg/shap

Expand source code
def model_parts(self,
                loss_function=None,
                type=('variable_importance', 'ratio', 'difference', 'shap_wrapper'),
                N=1000,
                B=10,
                variables=None,
                variable_groups=None,
                keep_raw_permutations=True,
                label=None,
                processes=1,
                random_state=None,
                **kwargs):

    """Calculate model-level variable importance

    Parameters
    -----------
    loss_function : {'rmse', '1-auc', 'mse', 'mae', 'mad'} or function, optional
        If a string, that loss function will be used to assess variable importance
        (default is `'rmse'` or `'1-auc'`, depending on the `model_type` attribute).
    type : {'variable_importance', 'ratio', 'difference', 'shap_wrapper'}
        Type of transformation that will be applied to dropout loss.
        (default is `'variable_importance'`, which is Permutational Variable Importance).
    N : int, optional
        Number of observations that will be sampled from the `data` attribute before
        the calculation of variable importance. `None` means all `data` (default is `1000`).
    B : int, optional
        Number of permutation rounds to perform on each variable (default is `10`).
    variables : array_like of str, optional
        Variables for which the importance will be calculated
        (default is `None`, which means all of the variables).
        NOTE: Ignored if `variable_groups` is not `None`.
    variable_groups : dict of lists, optional
        Group the variables to calculate their joint variable importance
        e.g. `{'X': ['x1', 'x2'], 'Y': ['y1', 'y2']}` (default is `None`).
    keep_raw_permutations : bool, optional
        Save results for all permutation rounds (default is `True`).
    label : str, optional
        Name to appear in result and plots. Overrides default.
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `B`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).
    kwargs : dict
        Used only for 'shap_wrapper'. Pass `shap_explainer_type` to specify which
        Explainer shall be used: `{'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
        'LinearExplainer', 'KernelExplainer'}`.
        Also keyword arguments passed to one of the: `shap.TreeExplainer.shap_values,
        shap.DeepExplainer.shap_values, shap.GradientExplainer.shap_values,
        shap.LinearExplainer.shap_values, shap.KernelExplainer.shap_values`.
        See https://github.com/slundberg/shap

    Returns
    -----------
    VariableImportance or ShapWrapper class object
        Explanation object containing the main result attribute and the plot method.
        Object class, its attributes, and the plot method depend on the `type` parameter.

    Notes
    --------
    - https://pbiecek.github.io/ema/featureImportance.html
    - https://github.com/slundberg/shap
    """

    checks.check_data_again(self.data)

    types = ('variable_importance', 'ratio', 'difference', 'shap_wrapper')
    aliases = {'permutational': 'variable_importance', 'feature_importance': 'variable_importance'}
    _type = checks.check_method_type(type, types, aliases)

    loss_function = checks.check_method_loss_function(self, loss_function)

    if _type != 'shap_wrapper':
        checks.check_y_again(self.y)

        _model_parts = VariableImportance(
            loss_function=loss_function,
            type=_type,
            N=N,
            B=B,
            variables=variables,
            variable_groups=variable_groups,
            processes=processes,
            random_state=random_state,
            keep_raw_permutations=keep_raw_permutations,
        )
        _model_parts.fit(self)

        if label:
            _model_parts.result['label'] = label

    elif _type == 'shap_wrapper':
        _global_checks.global_check_import('shap', 'SHAP explanations')
        _model_parts = ShapWrapper('model_parts')
        if N is None:
            N = self.data.shape[0]
        else:
            N = min(N, self.data.shape[0])

        # sample N rows from the whole data, not only from its first N rows
        sampled_rows = np.random.choice(self.data.shape[0], N, replace=False)
        sampled_data = self.data.iloc[sampled_rows, :]

        _model_parts.fit(self, sampled_data, **kwargs)
    else:
        raise TypeError("Wrong type parameter.")

    return _model_parts
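
A usage sketch, assuming the exp object from the dump() example; variable names refer to that numeric titanic subset:

    vi = exp.model_parts(N=500, B=10, random_state=0)
    vi.result            # dropout_loss per variable
    vi.plot(max_vars=4)

    # joint importance of user-defined groups of variables
    vi_grouped = exp.model_parts(
        variable_groups={'family': ['sibsp', 'parch'], 'wealth': ['fare']})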
def model_performance(self, model_type=None, cutoff=0.5, label=None)

Calculate model-level model performance measures

Parameters

model_type : {'regression', 'classification', None}
Model task type that is used to choose the proper performance measures (default is None, which means try to extract from the model_type attribute).
cutoff : float, optional
Cutoff for predictions in classification models. Needed for measures like recall, precision, acc, f1 (default is 0.5).
label : str, optional
Name to appear in result and plots. Overrides default.

Returns

ModelPerformance class object
Explanation object containing the main result attribute and the plot method.

Notes

- https://pbiecek.github.io/ema/modelPerformance.html

Expand source code
def model_performance(self,
                      model_type=None,
                      cutoff=0.5,
                      label=None):
    """Calculate model-level model performance measures

    Parameters
    -----------
    model_type : {'regression', 'classification', None}
        Model task type that is used to choose the proper performance measures
        (default is `None`, which means try to extract from the `model_type` attribute).
    cutoff : float, optional
        Cutoff for predictions in classification models. Needed for measures like
        recall, precision, acc, f1 (default is `0.5`).
    label : str, optional
        Name to appear in result and plots. Overrides default.

    Returns
    -----------
    ModelPerformance class object
        Explanation object containing the main result attribute and the plot method.

    Notes
    --------
    - https://pbiecek.github.io/ema/modelPerformance.html
    """

    checks.check_data_again(self.data)
    checks.check_y_again(self.y)

    if model_type is None and self.model_type is None:
        raise TypeError("if self.model_type is None, then model_type must not be None")
    elif model_type is None:
        model_type = self.model_type

    _model_performance = ModelPerformance(
        model_type=model_type,
        cutoff=cutoff
    )
    _model_performance.fit(self)

    if label:
        _model_performance.result['label'] = label

    return _model_performance
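
A usage sketch, assuming the exp object from the dump() example:

    mp = exp.model_performance(model_type='classification', cutoff=0.5)
    mp.result    # recall, precision, f1, accuracy, auc for this classifier
    mp.plot()    # compares residual distributions of the wrapped model(s)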
def model_profile(self, type=('partial', 'accumulated', 'conditional'), N=300, variables=None, variable_type='numerical', groups=None, span=0.25, grid_points=101, variable_splits=None, variable_splits_type='uniform', center=True, label=None, processes=1, random_state=None, verbose=True)

Calculate model-level variable profiles as Partial or Accumulated Dependence

Parameters

type : {'partial', 'accumulated', 'conditional'}
Type of model profiles (default is 'partial' for Partial Dependence Profiles).
N : int, optional
Number of observations that will be sampled from the data attribute before the calculation of variable profiles. None means all data (default is 300).
variables : str or array_like of str, optional
Variables for which the profiles will be calculated (default is None, which means all of the variables).
variable_type : {'numerical', 'categorical'}
Calculate the profiles for numerical or categorical variables (default is 'numerical').
groups : str or array_like of str, optional
Names of categorical variables that will be used for profile grouping (default is None, which means no grouping).
span : float, optional
Smoothing coefficient used as sd for gaussian kernel (default is 0.25).
grid_points : int, optional
Maximum number of points for profile calculations (default is 101). NOTE: The final number of points may be lower than grid_points, e.g. if there are not enough unique values for a given variable.
variable_splits : dict of lists, optional
Split points for variables e.g. {'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']} (default is None, which means that they will be distributed uniformly).
variable_splits_type : {'uniform', 'quantiles'}, optional
Way of calculating variable_splits. Set 'quantiles' for percentiles. (default is 'uniform', which means uniform grid of points).
center : bool, optional
Theoretically Accumulated Profiles start at 0, but are centered to compare them with Partial Dependence Profiles (default is True, which means center around the average y_hat calculated on the data sample).
label : str, optional
Name to appear in result and plots. Overrides default.
processes : int, optional
Number of parallel processes to use in calculations. Iterated over variables (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).
verbose : bool, optional
Print tqdm progress bar (default is True).

Returns

AggregatedProfiles class object
Explanation object containing the main result attribute and the plot method.

Notes

- https://pbiecek.github.io/ema/partialDependenceProfiles.html
- https://pbiecek.github.io/ema/accumulatedLocalProfiles.html

Expand source code
def model_profile(self,
                  type=('partial', 'accumulated', 'conditional'),
                  N=300,
                  variables=None,
                  variable_type='numerical',
                  groups=None,
                  span=0.25,
                  grid_points=101,
                  variable_splits=None,
                  variable_splits_type='uniform',
                  center=True,
                  label=None,
                  processes=1,
                  random_state=None,
                  verbose=True):

    """Calculate model-level variable profiles as Partial or Accumulated Dependence

    Parameters
    -----------
    type : {'partial', 'accumulated', 'conditional'}
        Type of model profiles
        (default is `'partial'` for Partial Dependence Profiles).
    N : int, optional
        Number of observations that will be sampled from the `data` attribute before
        the calculation of variable profiles. `None` means all `data` (default is `300`).
    variables : str or array_like of str, optional
        Variables for which the profiles will be calculated
        (default is `None`, which means all of the variables).
    variable_type : {'numerical', 'categorical'}
        Calculate the profiles for numerical or categorical variables
        (default is `'numerical'`).
    groups : str or array_like of str, optional
        Names of categorical variables that will be used for profile grouping
        (default is `None`, which means no grouping).
    span : float, optional
        Smoothing coefficient used as sd for gaussian kernel (default is `0.25`).
    grid_points : int, optional
        Maximum number of points for profile calculations (default is `101`).
        NOTE: The final number of points may be lower than `grid_points`,
        e.g. if there are not enough unique values for a given variable.
    variable_splits : dict of lists, optional
        Split points for variables e.g. `{'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']}`
        (default is `None`, which means that they will be distributed uniformly).
    variable_splits_type : {'uniform', 'quantiles'}, optional
        Way of calculating `variable_splits`. Set 'quantiles' for percentiles.
        (default is `'uniform'`, which means uniform grid of points).
    center : bool, optional
        Theoretically Accumulated Profiles start at `0`, but are centered to compare
        them with Partial Dependence Profiles (default is `True`, which means center
        around the average `y_hat` calculated on the data sample).
    label : str, optional
        Name to appear in result and plots. Overrides default.
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `variables`
        (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).
    verbose : bool, optional
        Print tqdm progress bar (default is `True`).

    Returns
    -----------
    AggregatedProfiles class object
        Explanation object containing the main result attribute and the plot method.

    Notes
    --------
    - https://pbiecek.github.io/ema/partialDependenceProfiles.html
    - https://pbiecek.github.io/ema/accumulatedLocalProfiles.html
    """

    checks.check_data_again(self.data)

    types = ('partial', 'accumulated', 'conditional')
    aliases = {'pdp': 'partial', 'ale': 'accumulated'}
    _type = checks.check_method_type(type, types, aliases)

    if N is None:
        N = self.data.shape[0]
    else:
        N = min(N, self.data.shape[0])

    if random_state is not None:
        np.random.seed(random_state)

    # sample N rows from the whole data, not only from its first N rows
    I = np.random.choice(self.data.shape[0], N, replace=False)

    _ceteris_paribus = CeterisParibus(grid_points=grid_points,
                                     variables=variables,
                                     variable_splits=variable_splits,
                                     variable_splits_type=variable_splits_type,
                                     processes=processes)
    _y = self.y[I] if self.y is not None else self.y
    _ceteris_paribus.fit(self, self.data.iloc[I, :], y=_y, verbose=verbose)

    _model_profile = AggregatedProfiles(
        type=_type,
        variables=variables,
        variable_type=variable_type,
        groups=groups,
        span=span,
        center=center,
        random_state=random_state
    )

    _model_profile.fit(_ceteris_paribus, verbose)

    if label:
        _model_profile.result['_label_'] = label

    return _model_profile
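
A usage sketch, assuming the exp object from the dump() example; it overlays Partial and Accumulated Dependence for two variables:

    pdp = exp.model_profile(type='partial', variables=['age', 'fare'],
                            random_state=0, verbose=False)
    ale = exp.model_profile(type='accumulated', variables=['age', 'fare'],
                            label='ale', random_state=0, verbose=False)
    pdp.plot(ale)   # plot accepts other explanation objects for comparison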
def model_surrogate(self, type=('tree', 'linear'), max_vars=5, max_depth=3, **kwargs)

Create a surrogate interpretable model from the black-box model

This method uses the scikit-learn package to create a surrogate interpretable model (e.g. decision tree) from the black-box model. It aims to use the most important features and add a plot method to the model, so that it can be easily interpreted. See Notes section for references.

Parameters

type : {'tree', 'linear'}
Type of a surrogate model. This can be a decision tree or a linear model (default is 'tree').
max_vars : int, optional
Maximum number of variables that will be used in surrogate model training. These are the most important variables to the black-box model (default is 5).
max_depth : int, optional
The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than min_samples_split samples (default is 3 for interpretable plot).
kwargs : dict
Keyword arguments passed to one of the: sklearn.tree.DecisionTreeClassifier, sklearn.tree.DecisionTreeRegressor, sklearn.linear_model.LogisticRegression, sklearn.linear_model.LinearRegression

Returns

One of: sklearn.tree.DecisionTreeClassifier, sklearn.tree.DecisionTreeRegressor, sklearn.linear_model.LogisticRegression, sklearn.linear_model.LinearRegression
A surrogate model with additional:
  • plot method
  • performance attribute
  • feature_names attribute
  • class_names attribute

Notes

- https://christophm.github.io/interpretable-ml-book/global.html
- https://github.com/scikit-learn/scikit-learn

Expand source code
def model_surrogate(self,
                    type=('tree', 'linear'),
                    max_vars=5,
                    max_depth=3,
                    **kwargs):
    """Create a surrogate interpretable model from the black-box model

    This method uses the scikit-learn package to create a surrogate
    interpretable model (e.g. decision tree) from the black-box model.
    It aims to use the most important features and add a plot method to
    the model, so that it can be easily interpreted. See Notes section
    for references.

    Parameters
    -----------
    type : {'tree', 'linear'}
        Type of a surrogate model. This can be a decision tree or a linear model
        (default is `'tree'`).
    max_vars : int, optional
        Maximum number of variables that will be used in surrogate model training.
        These are the most important variables to the black-box model (default is `5`).
    max_depth : int, optional
        The maximum depth of the tree. If `None`, then nodes are expanded until all
        leaves are pure or until all leaves contain less than min_samples_split
        samples (default is `3` for interpretable plot).
    kwargs : dict
        Keyword arguments passed to one of the: `sklearn.tree.DecisionTreeClassifier,
        sklearn.tree.DecisionTreeRegressor, sklearn.linear_model.LogisticRegression,
        sklearn.linear_model.LinearRegression`


    Returns
    -----------
    One of: sklearn.tree.DecisionTreeClassifier, sklearn.tree.DecisionTreeRegressor, sklearn.linear_model.LogisticRegression, sklearn.linear_model.LinearRegression
    A surrogate model with additional:
        - `plot` method
        - `performance` attribute
        - `feature_names` attribute
        - `class_names` attribute

    Notes
    -----------
    - https://christophm.github.io/interpretable-ml-book/global.html
    - https://github.com/scikit-learn/scikit-learn
    """

    _global_checks.global_check_import('scikit-learn', 'surrogate models')
    checks.check_data_again(self.data)

    types = ('tree', 'linear')
    _type = checks.check_method_type(type, types)

    surrogate_model = utils.create_surrogate_model(explainer=self,
                                                   type=_type,
                                                   max_vars=max_vars,
                                                   max_depth=max_depth,
                                                   **kwargs)

    return surrogate_model
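
A usage sketch, assuming the exp object from the dump() example:

    surrogate = exp.model_surrogate(type='tree', max_vars=4, max_depth=3)
    surrogate.performance   # how faithfully the tree mimics the black-box model
    surrogate.plot()        # renders the fitted decision tree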
def predict(self, data)

Make a prediction

This function uses the predict_function attribute.

Parameters

data : pd.DataFrame, np.ndarray (2d)
Data which will be used to make a prediction.

Returns

np.ndarray (1d)
Model predictions for given data.
Expand source code
def predict(self, data):
    """Make a prediction

    This function uses the `predict_function` attribute.

    Parameters
    ----------
    data : pd.DataFrame, np.ndarray (2d)
        Data which will be used to make a prediction.

    Returns
    ----------
    np.ndarray (1d)
        Model predictions for given `data`.
    """

    checks.check_method_data(data)

    return self.predict_function(self.model, data)
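
A one-line sketch, assuming the exp object from the dump() example:

    exp.predict(X.iloc[:5])   # np.ndarray of 5 predicted probabilities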
def predict_parts(self, new_observation, type=('break_down_interactions', 'break_down', 'shap', 'shap_wrapper'), order=None, interaction_preference=1, path='average', B=25, keep_distributions=False, label=None, processes=1, random_state=None, **kwargs)

Calculate predict-level variable attributions as Break Down, Shapley Values or Shap Values

Parameters

new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
An observation for which a prediction needs to be explained.
type : {'break_down_interactions', 'break_down', 'shap', 'shap_wrapper'}
Type of variable attributions (default is 'break_down_interactions').
order : list of int or str, optional
Parameter specific for break_down_interactions and break_down. Use a fixed order of variables for attribution calculation. Use integer values or string variable names (default is None, which means order by importance).
interaction_preference : int, optional
Parameter specific for break_down_interactions type. Specify which interactions will be present in an explanation. The larger the integer, the more frequently interactions will be presented (default is 1).
path : list of int, optional
Parameter specific for shap. If specified, then attributions for this path will be plotted (default is 'average', which plots attribution means for B random paths).
B : int, optional
Parameter specific for shap. Number of random paths to calculate variable attributions (default is 25).
keep_distributions : bool, optional
Save the distribution of partial predictions (default is False).
label : str, optional
Name to appear in result and plots. Overrides default.
processes : int, optional
Parameter specific for shap. Number of parallel processes to use in calculations. Iterated over B (default is 1, which means no parallel computation).
random_state : int, optional
Set seed for random number generator (default is random seed).
kwargs : dict
Used only for 'shap_wrapper'. Pass shap_explainer_type to specify which Explainer shall be used: {'TreeExplainer', 'DeepExplainer', 'GradientExplainer', 'LinearExplainer', 'KernelExplainer'} (default is None, which automatically chooses an Explainer to use). Also keyword arguments passed to one of the: shap.TreeExplainer.shap_values, shap.DeepExplainer.shap_values, shap.GradientExplainer.shap_values, shap.LinearExplainer.shap_values, shap.KernelExplainer.shap_values. See https://github.com/slundberg/shap

Returns

BreakDown, Shap or ShapWrapper class object
Explanation object containing the main result attribute and the plot method. Object class, its attributes, and the plot method depend on the type parameter.

Notes

- https://pbiecek.github.io/ema/breakDown.html
- https://pbiecek.github.io/ema/iBreakDown.html
- https://pbiecek.github.io/ema/shapley.html
- https://github.com/slundberg/shap

Expand source code
def predict_parts(self,
                  new_observation,
                  type=('break_down_interactions', 'break_down', 'shap', 'shap_wrapper'),
                  order=None,
                  interaction_preference=1,
                  path="average",
                  B=25,
                  keep_distributions=False,
                  label=None,
                  processes=1,
                  random_state=None,
                  **kwargs):
    """Calculate predict-level variable attributions as Break Down, Shapley Values or Shap Values

    Parameters
    -----------
    new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
        An observation for which a prediction needs to be explained.
    type : {'break_down_interactions', 'break_down', 'shap', 'shap_wrapper'}
        Type of variable attributions (default is `'break_down_interactions'`).
    order : list of int or str, optional
        Parameter specific for `break_down_interactions` and `break_down`. Use a fixed
        order of variables for attribution calculation. Use integer values  or string
        variable names (default is `None`, which means order by importance).
    interaction_preference : int, optional
        Parameter specific for `break_down_interactions` type. Specify which interactions
        will be present in an explanation. The larger the integer, the more frequently
        interactions will be presented (default is `1`).
    path : list of int, optional
        Parameter specific for `shap`. If specified, then attributions for this path
        will be plotted (default is `'average'`, which plots attribution means for
        `B` random paths).
    B : int, optional
        Parameter specific for `shap`. Number of random paths to calculate
        variable attributions (default is `25`).
    keep_distributions : bool, optional
        Save the distribution of partial predictions (default is `False`).
    label : str, optional
        Name to appear in result and plots. Overrides default.
    processes : int, optional
        Parameter specific for `shap`. Number of parallel processes to use in calculations.
        Iterated over `B` (default is `1`, which means no parallel computation).
    random_state : int, optional
        Set seed for random number generator (default is random seed).
    kwargs : dict
        Used only for `'shap_wrapper'`. Pass `shap_explainer_type` to specify which
        Explainer shall be used: `{'TreeExplainer', 'DeepExplainer', 'GradientExplainer',
        'LinearExplainer', 'KernelExplainer'}` (default is `None`, which automatically
        chooses an Explainer to use).
        Also keyword arguments passed to one of the: `shap.TreeExplainer.shap_values,
        shap.DeepExplainer.shap_values, shap.GradientExplainer.shap_values,
        shap.LinearExplainer.shap_values, shap.KernelExplainer.shap_values`.
        See https://github.com/slundberg/shap

    Returns
    -----------
    BreakDown, Shap or ShapWrapper class object
        Explanation object containing the main result attribute and the plot method.
        Object class, its attributes, and the plot method depend on the `type` parameter.

    Notes
    --------
    - https://pbiecek.github.io/ema/breakDown.html
    - https://pbiecek.github.io/ema/iBreakDown.html
    - https://pbiecek.github.io/ema/shapley.html
    - https://github.com/slundberg/shap
    """

    checks.check_data_again(self.data)

    types = ('break_down_interactions', 'break_down', 'shap', 'shap_wrapper')
    _type = checks.check_method_type(type, types)

    if _type == 'break_down_interactions' or _type == 'break_down':
        _predict_parts = BreakDown(
            type=_type,
            keep_distributions=keep_distributions,
            order=order,
            interaction_preference=interaction_preference
        )
    elif _type == 'shap':
        _predict_parts = Shap(
            keep_distributions=keep_distributions,
            path=path,
            B=B,
            processes=processes,
            random_state=random_state
        )
    elif _type == 'shap_wrapper':
        _global_checks.global_check_import('shap', 'SHAP explanations')
        _predict_parts = ShapWrapper('predict_parts')
    else:
        raise TypeError("Wrong type parameter.")

    _predict_parts.fit(self, new_observation, **kwargs)

    if label:
        _predict_parts.result['label'] = label

    return _predict_parts
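
A usage sketch, assuming the exp object from the dump() example:

    observation = X.iloc[[0]]   # a single-row DataFrame
    bd = exp.predict_parts(observation, type='break_down')
    sh = exp.predict_parts(observation, type='shap', B=25, random_state=0)
    bd.plot()
    sh.plot()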
def predict_profile(self, new_observation, type=('ceteris_paribus',), y=None, variables=None, grid_points=101, variable_splits=None, variable_splits_type='uniform', variable_splits_with_obs=True, processes=1, label=None, verbose=True)

Calculate predict-level variable profiles as Ceteris Paribus

Parameters

new_observation : pd.DataFrame or np.ndarray or pd.Series
Observations for which predictions need to be explained.
type : {'ceteris_paribus', TODO: 'oscillations'}
Type of variable profiles (default is 'ceteris_paribus').
y : pd.Series or np.ndarray (1d), optional
Target variable with the same length as new_observation.
variables : str or array_like of str, optional
Variables for which the profiles will be calculated (default is None, which means all of the variables).
grid_points : int, optional
Maximum number of points for profile calculations (default is 101). NOTE: The final number of points may be lower than grid_points, e.g. if there are not enough unique values for a given variable.
variable_splits : dict of lists, optional
Split points for variables e.g. {'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']} (default is None, which means that they will be calculated using one of variable_splits_type and the data attribute).
variable_splits_type : {'uniform', 'quantiles'}, optional
Way of calculating variable_splits. Set 'quantiles' for percentiles. (default is 'uniform', which means uniform grid of points).
variable_splits_with_obs : bool, optional
Add variable values of new_observation data to the variable_splits (default is True).
label : str, optional
Name to appear in result and plots. Overrides default.
processes : int, optional
Number of parallel processes to use in calculations. Iterated over variables (default is 1, which means no parallel computation).
verbose : bool, optional
Print tqdm progress bar (default is True).

Returns

CeterisParibus class object
Explanation object containing the main result attribute and the plot method.

Notes

- https://pbiecek.github.io/ema/ceterisParibus.html

Expand source code
def predict_profile(self,
                    new_observation,
                    type=('ceteris_paribus',),
                    y=None,
                    variables=None,
                    grid_points=101,
                    variable_splits=None,
                    variable_splits_type='uniform',
                    variable_splits_with_obs=True,
                    processes=1,
                    label=None,
                    verbose=True):
    """Calculate predict-level variable profiles as Ceteris Paribus

    Parameters
    -----------
    new_observation : pd.DataFrame or np.ndarray or pd.Series
        Observations for which predictions need to be explained.
    type : {'ceteris_paribus', TODO: 'oscillations'}
        Type of variable profiles (default is `'ceteris_paribus'`).
    y : pd.Series or np.ndarray (1d), optional
        Target variable with the same length as `new_observation`.
    variables : str or array_like of str, optional
        Variables for which the profiles will be calculated
        (default is `None`, which means all of the variables).
    grid_points : int, optional
        Maximum number of points for profile calculations (default is `101`).
        NOTE: The final number of points may be lower than `grid_points`,
        e.g. if there are not enough unique values for a given variable.
    variable_splits : dict of lists, optional
        Split points for variables e.g. `{'x': [0, 0.2, 0.5, 0.8, 1], 'y': ['a', 'b']}`
        (default is `None`, which means that they will be calculated using one of
        `variable_splits_type` and the `data` attribute).
    variable_splits_type : {'uniform', 'quantiles'}, optional
        Way of calculating `variable_splits`. Set `'quantiles'` for percentiles.
        (default is `'uniform'`, which means uniform grid of points).
    variable_splits_with_obs : bool, optional
        Add variable values of `new_observation` data to the `variable_splits`
        (default is `True`).
    label : str, optional
        Name to appear in result and plots. Overrides default.
    processes : int, optional
        Number of parallel processes to use in calculations. Iterated over `variables`
        (default is `1`, which means no parallel computation).
    verbose : bool, optional
        Print tqdm progress bar (default is `True`).

    Returns
    -----------
    CeterisParibus class object
        Explanation object containing the main result attribute and the plot method.

    Notes
    --------
    - https://pbiecek.github.io/ema/ceterisParibus.html
    """

    checks.check_data_again(self.data)

    types = ('ceteris_paribus',)
    _type = checks.check_method_type(type, types)

    if _type == 'ceteris_paribus':
        _predict_profile = CeterisParibus(
            variables=variables,
            grid_points=grid_points,
            variable_splits=variable_splits,
            variable_splits_type=variable_splits_type,
            variable_splits_with_obs=variable_splits_with_obs,
            processes=processes
        )
    else:
        raise TypeError("Wrong type parameter.")

    _predict_profile.fit(self, new_observation, y, verbose)

    if label:
        _predict_profile.result['_label_'] = label

    return _predict_profile
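
A usage sketch, assuming the exp object from the dump() example:

    cp = exp.predict_profile(X.iloc[[0]], variables=['age', 'fare'], verbose=False)
    cp.plot()   # Ceteris Paribus curves for the selected variables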
def predict_surrogate(self, new_observation, type='lime', **kwargs)

Wrapper for surrogate model explanations

This function uses the lime package to create the model explanation. See https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

Parameters

new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
An observation for which a prediction needs to be explained.
type : {'lime'}
Type of explanation method (default is 'lime', which uses the lime package to create an explanation).
kwargs : dict
Keyword arguments passed to the lime.lime_tabular.LimeTabularExplainer object and the LimeTabularExplainer.explain_instance method. Exceptions are: training_data, mode, data_row and predict_fn. Other parameters: https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

Returns

lime.explanation.Explanation
Explanation object.

Notes

- https://github.com/marcotcr/lime

Expand source code
def predict_surrogate(self,
                      new_observation,
                      type='lime',
                      **kwargs):
    """Wrapper for surrogate model explanations

    This function uses the lime package to create the model explanation.
    See https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

    Parameters
    -----------
    new_observation : pd.Series or np.ndarray (1d) or pd.DataFrame (1,p)
        An observation for which a prediction needs to be explained.
    type : {'lime'}
        Type of explanation method
        (default is `'lime'`, which uses the lime package to create an explanation).
    kwargs : dict
        Keyword arguments passed to the lime.lime_tabular.LimeTabularExplainer object
        and the LimeTabularExplainer.explain_instance method. Exceptions are:
        `training_data`, `mode`, `data_row` and `predict_fn`. Other parameters:
        https://lime-ml.readthedocs.io/en/latest/lime.html#module-lime.lime_tabular

    Returns
    -----------
    lime.explanation.Explanation
        Explanation object.

    Notes
    -----------
    - https://github.com/marcotcr/lime
    """

    checks.check_data_again(self.data)

    if type == 'lime':
        _global_checks.global_check_import('lime', 'LIME explanations')
        _new_observation = checks.check_new_observation_lime(new_observation)
        _explanation = utils.create_lime_explanation(self, _new_observation, **kwargs)
    else:
        raise TypeError("Wrong 'type' parameter.")

    return _explanation
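
A usage sketch, assuming the exp object from the dump() example and an installed lime package; as_list() and show_in_notebook() are methods of lime's Explanation class:

    lime_expl = exp.predict_surrogate(X.iloc[0], type='lime')
    lime_expl.as_list()   # (feature, weight) pairs
    # in a notebook: lime_expl.show_in_notebook()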
def residual(self, data, y)

Calculate residuals

This function uses the residual_function attribute.

Parameters

data : pd.DataFrame
Data which will be used to calculate residuals.
y : pd.Series or np.ndarray (1d)
Target variable which will be used to calculate residuals.

Returns

np.ndarray (1d)
Model residuals for given data and y.
Expand source code
def residual(self, data, y):
    """Calculate residuals

    This function uses the `residual_function` attribute.

    Parameters
    -----------
    data : pd.DataFrame
        Data which will be used to calculate residuals.
    y : pd.Series or np.ndarray (1d)
        Target variable which will be used to calculate residuals.

    Returns
    -----------
    np.ndarray (1d)
        Model residuals for given `data` and `y`.
    """

    checks.check_method_data(data)

    return self.residual_function(self.model, data, y)
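
A one-line sketch, assuming the exp object from the dump() example; the default residual_function returns y minus the model predictions:

    exp.residual(X.iloc[:5], y.iloc[:5])   # np.ndarray of 5 residuals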