Module dalex.arena
Expand source code
from .object import Arena
__all__ = [
"Arena"
]
Sub-modules
dalex.arena.object
dalex.arena.params
dalex.arena.plots
dalex.arena.server
dalex.arena.static
Classes
class Arena (precalculate=False, enable_attributes=True, enable_custom_params=True)
-
Creates Arena object
This class should be used to create Python connector for Arena dashboard. Initialized object can work both in static and live mode. Use
push_*
methods to add your models, observations and datasets.
Parameters
precalculate
:bool
- Enables precalculating plots after using each
push_*
method.
enable_attributes
:bool
- Enables attributes of observations and variables. Attributes are required to display details of observation in Arena, but it also increases generated file size.
enable_custom_params
:bool
- Enables modifying observations in dashboard. It requires attributes and works only in live version.
Attributes
models
:list
of ModelParam objects
- List of pushed models encapsulated in ModelParam class
observations
:list
of ObservationParam objects
- List of pushed observations encapsulated in ObservationParam class
datasets
:list
of DatasetParam objects
- List of pushed datasets encapsulated in DatasetParam class
variables_cache
:list
of VariableParam objects
- Cached list of VariableParam objects generated using pushed models and datasets
server_thread
:threading.Thread
- Thread of running server or None otherwise
precalculate
:bool
- if plots should be precalculated
enable_attributes
:bool
- if attributes are enabled
enable_custom_params
:bool
- if modifying observations is enabled
timestamp
:float
- timestamp of last modification
cache
:list
of PlotContainer objects
- List of already calculated plots
mutex
:_thread.lock
- Mutex for params and cache
plots
:list
of classes extending PlotContainer
- List of enabled plots
options
:dict
- Options for plots
Notes
For tutorial look at https://arena.drwhy.ai/docs/guide/first-datasource
Expand source code
class Arena:
    """
    Creates Arena object

    This class should be used to create Python connector for Arena dashboard.
    Initialized object can work both in static and live mode. Use `push_*`
    methods to add your models, observations and datasets.

    Parameters
    ----------
    precalculate : bool
        Enables precalculating plots after using each `push_*` method.
    enable_attributes : bool
        Enables attributes of observations and variables. Attributes are
        required to display details of observation in Arena, but it also
        increases generated file size.
    enable_custom_params : bool
        Enables modifying observations in dashboard. It requires attributes
        and works only in live version.

    Attributes
    ----------
    models : list of ModelParam objects
        List of pushed models encapsulated in ModelParam class
    observations : list of ObservationParam objects
        List of pushed observations encapsulated in ObservationParam class
    datasets : list of DatasetParam objects
        List of pushed datasets encapsulated in DatasetParam class
    variables_cache : list of VariableParam objects
        Cached list of VariableParam objects generated using pushed models
        and datasets
    server_thread : threading.Thread
        Thread of running server or None otherwise
    precalculate : bool
        if plots should be precalculated
    enable_attributes : bool
        if attributes are enabled
    enable_custom_params : bool
        if modifying observations is enabled
    timestamp : float
        timestamp of last modification
    cache : list of PlotContainer objects
        List of already calculated plots
    mutex : _thread.lock
        Mutex for params and cache
    plots : list of classes extending PlotContainer
        List of enabled plots
    options : dict
        Options for plots

    Notes
    -----
    For tutorial look at https://arena.drwhy.ai/docs/guide/first-datasource
    """

    def __init__(self, precalculate=False, enable_attributes=True, enable_custom_params=True):
        self.models = []
        self.observations = []
        self.datasets = []
        self.variables_cache = []
        self.server_thread = None
        self.precalculate = bool(precalculate)
        self.enable_attributes = bool(enable_attributes)
        self.enable_custom_params = bool(enable_custom_params)
        self.timestamp = datetime.timestamp(datetime.now())
        self.cache = []
        self.mutex = threading.Lock()
        self.plots = [
            ShapleyValuesContainer,
            FeatureImportanceContainer,
            PartialDependenceContainer,
            AccumulatedDependenceContainer,
            CeterisParibusContainer,
            BreakDownContainer,
            MetricsContainer,
            ROCContainer,
            FairnessCheckContainer
        ]
        # Seed every plot type's options with the defaults declared by the plot class.
        self.options = {}
        for plot in self.plots:
            self.options[plot.info.get('plotType')] = {
                name: plot.options.get(name).get('default') for name in plot.options.keys()
            }

    def get_supported_plots(self):
        """Returns plots classes that can produce at least one valid chart for this Arena.

        Returns
        -------
        List of classes extending PlotContainer
        """
        return [plot for plot in self.plots if plot.test_arena(self)]

    def run_server(self,
                   host='127.0.0.1',
                   port=8181,
                   append_data=False,
                   arena_url='https://arena.drwhy.ai/',
                   disable_logs=True):
        """Starts server for live mode of Arena

        Parameters
        ----------
        host : str
            ip or hostname for the server
        port : int
            port number for the server
        append_data : bool
            if generated link should append data to already existing Arena window.
        arena_url : str
            URL of Arena dashboard
        disable_logs : bool
            if logs should be muted

        Notes
        -----
        Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

        The link to Arena is printed to standard output; this method does
        not return it.

        Returns
        -------
        None

        Raises
        ------
        Exception
            If a server thread is already registered on this object.
        """
        if self.server_thread:
            raise Exception('Server is already running. To stop ip use arena.stop_server().')
        global_check_import('flask')
        global_check_import('flask_cors')
        global_check_import('requests')
        self.server_thread = threading.Thread(target=start_server, args=(self, host, port, disable_logs))
        self.server_thread.start()
        if append_data:
            print(arena_url + '?append=http://' + host + ':' + str(port) + '/')
        else:
            print(arena_url + '?data=http://' + host + ':' + str(port) + '/')

    def stop_server(self):
        """Stops running server

        Raises
        ------
        Exception
            If no server thread is registered on this object.
        """
        if not self.server_thread:
            raise Exception('Server is not running')
        # NOTE(review): `_stop_server` is not defined in this class; presumably
        # it is attached to the instance by the server start-up code - confirm.
        self._stop_server()
        self.server_thread.join()
        self.server_thread = None

    def update_timestamp(self):
        """Updates timestamp

        Notes
        -----
        This function must be called from mutex context
        """
        now = datetime.now()
        self.timestamp = datetime.timestamp(now)

    def clear_cache(self, plot_type=None):
        """Clears cache

        Parameters
        ----------
        plot_type : str or None
            If None all cache is cleared. Otherwise only plots with provided
            plot_type are removed.

        Notes
        -----
        This function must be called from mutex context
        """
        if plot_type is None:
            self.cache = []
        else:
            self.cache = [p for p in self.cache if p.plot_type != plot_type]
        self.update_timestamp()

    def find_in_cache(self, plot_type, params):
        """Function searches for cached plot

        Parameters
        ----------
        plot_type : str
            Value of plot_type field, that requested plot must have
        params : dict
            Keys of this dict are params types (model, observation, variable,
            dataset) and values are corresponding params labels. Requested
            plot must have equal params field.

        Returns
        -------
        PlotContainer or None
        """
        with self.mutex:
            return next(
                (p for p in self.cache if p.plot_type == plot_type and params == p.params),
                None
            )

    def put_to_cache(self, plot_container):
        """Puts new plot to cache

        Parameters
        ----------
        plot_container : PlotContainer

        Raises
        ------
        Exception
            If `plot_container` is not a PlotContainer instance.
        """
        if not isinstance(plot_container, PlotContainer):
            raise Exception('Invalid plot container')
        with self.mutex:
            self.cache.append(plot_container)

    def fill_cache(self, fixed_params=None):
        """Generates all available plots and caches them

        This function tries to generate all plots that are not cached already
        and put them to cache. Range of generated plots can be narrowed using
        `fixed_params`

        Parameters
        ----------
        fixed_params : dict or None
            This dict specifies which plots should be generated. Only those
            with all keys from `fixed_params` present and having the same
            value will be calculated. None (the default) fixes nothing.

        Raises
        ------
        Exception
            If `fixed_params` is neither None nor a dict.
        """
        # None replaces the former shared mutable `{}` default.
        if fixed_params is None:
            fixed_params = {}
        elif not isinstance(fixed_params, dict):
            raise Exception('Params argument must be a dict')
        for plot_class in self.get_supported_plots():
            required_params = plot_class.info.get('requiredParams')
            # Skip plots that do not use every param type fixed by the caller.
            if any(k not in required_params for k in fixed_params):
                continue
            available_params = self.get_available_params()
            # One pool per required param: the fixed value alone, or every available one.
            iteration_pools = [
                available_params.get(p) if fixed_params.get(p) is None else [fixed_params.get(p)]
                for p in required_params
            ]
            # Cartesian product of the pools, in required_params order.
            combinations = [[]]
            for pool in iteration_pools:
                combinations = [x + [y] for x in combinations for y in pool]
            for params_values in combinations:
                params = dict(zip(required_params, params_values))
                self.get_plot(plot_type=plot_class.info.get('plotType'), params_values=params)

    def push_model(self, explainer, precalculate=None):
        """Adds model to Arena

        This method encapsulates explainer in ModelParam object and appends
        it to the models field. When precalculation is enabled triggers
        filling cache.

        Parameters
        ----------
        explainer : dalex.Explainer
            Explainer created using dalex package
        precalculate : bool or None
            Overrides constructor `precalculate` parameter when it is not None.
            If true, then only plots using this model will be precalculated.

        Raises
        ------
        Exception
            If `explainer` is not an Explainer or its label was already pushed.
        """
        if not isinstance(explainer, Explainer):
            raise Exception('Invalid Explainer argument')
        if explainer.label in self.list_params('model'):
            raise Exception('Explainer with the same label was already added')
        precalculate = self.precalculate if precalculate is None else bool(precalculate)
        param = ModelParam(explainer)
        with self.mutex:
            self.update_timestamp()
            self.models.append(param)
            # Variables depend on pushed models, so the cached list is invalidated.
            self.variables_cache = []
        # Must run outside the mutex: fill_cache acquires it through get_params.
        if precalculate:
            self.fill_cache({'model': param})

    def push_observations(self, observations, precalculate=None):
        """Adds observations to Arena

        Pushed observations will be used for local explanations. Function
        creates ObservationParam object for each row of pushed dataset.
        Label for each observation is taken from row name. When
        precalculation is enabled triggers filling cache.

        Parameters
        ----------
        observations : pandas.DataFrame
            Data frame of observations to be explained using instance level
            plots. Label for each observation is taken from row name.
        precalculate : bool or None
            Overrides constructor `precalculate` parameter when it is not None.
            If true, then only plots using these observations will be
            precalculated.

        Raises
        ------
        Exception
            If `observations` is not a DataFrame, has a multi-level or
            non-unique index, or reuses a label of an already pushed
            observation.
        """
        if not isinstance(observations, DataFrame):
            raise Exception('Observations argument is not a pandas DataFrame')
        if len(observations.index.names) != 1:
            raise Exception('Observations argument need to have only one index')
        if not observations.index.is_unique:
            raise Exception('Observations argument need to have unique indexes')
        precalculate = self.precalculate if precalculate is None else bool(precalculate)
        old_observations = self.list_params('observation')
        # Labels are compared as strings, so the index is converted first.
        observations = observations.set_index(observations.index.astype(str))
        params_objects = []
        for x in observations.index:
            if x in old_observations:
                raise Exception('Indexes of observations need to be unique across all observations')
            params_objects.append(ObservationParam(dataset=observations, index=x))
        with self.mutex:
            self.update_timestamp()
            self.observations.extend(params_objects)
        # Must run outside the mutex: fill_cache acquires it through get_params.
        if precalculate:
            for obs in params_objects:
                self.fill_cache({'observation': obs})

    def push_dataset(self, dataset, target, label, precalculate=None):
        """Adds dataset to Arena

        Pushed dataset will be visualised using exploratory data analysis
        plots. Function creates DatasetParam object with specified label and
        target name. When precalculation is enabled triggers filling cache.

        Parameters
        ----------
        dataset : pandas.DataFrame
            Data frame to be visualised using EDA plots.
            This dataset should contain target variable.
        target : str
            Name of target column
        label : str
            Label for this dataset
        precalculate : bool or None
            Overrides constructor `precalculate` parameter when it is not None.
            If true, then only plots using this dataset will be precalculated.

        Raises
        ------
        Exception
            If `dataset` is not a DataFrame, has multi-level column names,
            lacks the target column, or `label` is empty, not a str, or
            already used.
        """
        if not isinstance(dataset, DataFrame):
            raise Exception('Dataset argument is not a pandas DataFrame')
        if len(dataset.columns.names) != 1:
            raise Exception('Dataset argument need to have only one level column names')
        precalculate = self.precalculate if precalculate is None else bool(precalculate)
        target = str(target)
        if target not in dataset.columns:
            raise Exception('Target is not a column from dataset')
        if (not isinstance(label, str)) or (len(label) == 0):
            raise Exception('Label need to be at least one letter')
        if label in self.list_params('dataset'):
            raise Exception('Labels need to be unique')
        param = DatasetParam(dataset=dataset, label=label, target=target)
        with self.mutex:
            self.update_timestamp()
            self.datasets.append(param)
            # Variables depend on pushed datasets, so the cached list is invalidated.
            self.variables_cache = []
        # Must run outside the mutex: fill_cache acquires it through get_params.
        if precalculate:
            self.fill_cache({'dataset': param})

    def get_params(self, param_type):
        """Returns list of available params

        Parameters
        ----------
        param_type : str
            One of ['model', 'variable', 'observation', 'dataset']. Params of
            this type will be returned

        Notes
        -----
        Information about params https://arena.drwhy.ai/docs/guide/params

        The returned list is the internal one, not a copy.

        Returns
        -------
        List of Param objects

        Raises
        ------
        Exception
            If `param_type` is not one of the supported types.
        """
        if param_type == 'observation':
            with self.mutex:
                return self.observations
        elif param_type == 'variable':
            with self.mutex:
                if not self.variables_cache:
                    # Extract column names from every dataset in self.datasets list and flatten it
                    result_datasets = [col for dataset in self.datasets for col in dataset.variables]
                    # Extract column names from every model in self.models list and flatten it
                    result_explainers = [col for model in self.models for col in model.variables]
                    result_str = np.unique(result_datasets + result_explainers).tolist()
                    self.variables_cache = [VariableParam(x) for x in result_str]
                    if self.enable_attributes:
                        for var in self.variables_cache:
                            try:
                                for dataset in self.datasets:
                                    if var.variable in dataset.variables:
                                        var.update_attributes(dataset.dataset[var.variable])
                                for model in self.models:
                                    if var.variable in model.variables:
                                        var.update_attributes(model.explainer.data[var.variable])
                            except Exception:
                                # Best effort: a variable whose attributes cannot be
                                # computed is kept without attributes. Narrowed from a
                                # bare `except:` so KeyboardInterrupt/SystemExit propagate.
                                var.clear_attributes()
                return self.variables_cache
        elif param_type == 'model':
            with self.mutex:
                return self.models
        elif param_type == 'dataset':
            with self.mutex:
                return self.datasets
        else:
            raise Exception('Invalid param type')

    def list_params(self, param_type):
        """Returns list of available params' labels

        Parameters
        ----------
        param_type : str
            One of ['model', 'variable', 'observation', 'dataset']. Labels of
            params of this type will be returned

        Notes
        -----
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        -------
        List of str
        """
        return [x.get_label() for x in self.get_params(param_type)]

    def get_available_params(self):
        """Returns dict containing available params of all types

        This method collects result of `get_params` method for each param
        type into a dict. Keys are param types and values are lists of Param
        objects.

        Notes
        -----
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        -------
        dict
        """
        result = {}
        for param_type in ['model', 'observation', 'variable', 'dataset']:
            result[param_type] = self.get_params(param_type)
        return result

    def list_available_params(self):
        """Returns dict containing labels of available params of all types

        This method collects result of `list_params` for each param type into
        a dict. Keys are param types and values are lists of labels.

        Notes
        -----
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        -------
        dict
        """
        result = {}
        for param_type in ['model', 'observation', 'variable', 'dataset']:
            result[param_type] = self.list_params(param_type)
        return result

    def find_param_value(self, param_type, param_label):
        """Searches for Param object with specified label

        Parameters
        ----------
        param_type : str
            One of ['model', 'variable', 'observation', 'dataset'].
        param_label : str
            Label of searched param

        Notes
        -----
        Information about params https://arena.drwhy.ai/docs/guide/params

        Returns
        -------
        Param or None
        """
        if param_label is None or not isinstance(param_label, str):
            return None
        return next((x for x in self.get_params(param_type) if x.get_label() == param_label), None)

    def print_options(self, plot_type=None):
        """Prints available options for plots

        Parameters
        ----------
        plot_type : str or None
            When not None, then only options for plots with this plot_type
            will be printed. An unknown plot_type prints options of all plots.

        Notes
        -----
        List of plots with described options for each one
        https://arena.drwhy.ai/docs/guide/observation-level
        """
        plot = next((x for x in self.plots if x.info.get('plotType') == plot_type), None)
        if plot_type is None or plot is None:
            for plot in self.plots:
                self.print_options(plot.info.get('plotType'))
            return
        print('\n\033[1m' + plot.info.get('plotType') + '\033[0m')
        print('---------------------------------')
        for o in plot.options.keys():
            option = plot.options.get(o)
            value = self.options.get(plot_type).get(o)
            print(o + ': ' + str(value) + ' #' + option.get('desc'))

    def get_option(self, plot_type, option):
        """Returns value of specified option

        Parameters
        ----------
        plot_type : str
            Type of plot, the option is corresponding to.
        option : str
            Name of the option

        Notes
        -----
        List of plots with described options for each one
        https://arena.drwhy.ai/docs/guide/observation-level

        Returns
        -------
        None or value of option

        Raises
        ------
        Exception
            If `plot_type` has no options registered.
        """
        options = self.options.get(plot_type)
        if options is None:
            raise Exception('Invalid plot_type')
        if option not in options.keys():
            return
        with self.mutex:
            return self.options.get(plot_type).get(option)

    def set_option(self, plot_type, option, value):
        """Sets value for the plot option

        Parameters
        ----------
        plot_type : str
            When None, then value will be set for each plot with option of
            name from `option` argument. Otherwise only for plots with
            specified type.
        option : str
            Name of the option
        value : *
            Value to be set

        Notes
        -----
        List of plots with described options for each one
        https://arena.drwhy.ai/docs/guide/observation-level

        Raises
        ------
        Exception
            If `plot_type` is not None and has no options registered.
        """
        if plot_type is None:
            for plot in self.plots:
                self.set_option(plot.info.get('plotType'), option, value)
            return
        options = self.options.get(plot_type)
        if options is None:
            raise Exception('Invalid plot_type')
        # Unknown option names are silently ignored.
        if option not in options.keys():
            return
        with self.mutex:
            self.options.get(plot_type)[option] = value
            # Plots computed with the old value are stale.
            self.clear_cache(plot_type)
        # Must run outside the mutex: fill_cache acquires it through get_params.
        if self.precalculate:
            self.fill_cache()

    def get_plot(self, plot_type, params_values, cache=True):
        """Returns plot for specified type and params

        Function searches for plot in cache, when not present creates
        requested plot and puts it to cache.

        Parameters
        ----------
        plot_type : str
            Type of plot to be generated
        params_values : dict
            Dict for param types as keys and Param objects as values
        cache : bool
            Whether to search for the plot in cache and put the calculated
            plot into cache.

        Returns
        -------
        PlotContainer

        Raises
        ------
        Exception
            If `plot_type` is not enabled or a required param is missing.
        """
        plot_class = next((c for c in self.plots if c.info.get('plotType') == plot_type), None)
        if plot_class is None:
            raise Exception('Not supported plot type')
        plot_type = plot_class.info.get('plotType')
        required_params_values = {}
        required_params_labels = {}
        for p in plot_class.info.get('requiredParams'):
            if params_values.get(p) is None:
                raise Exception('Required param is missing')
            required_params_values[p] = params_values.get(p)
            required_params_labels[p] = params_values.get(p).get_label()
        # Cache entries are looked up by labels, not by Param objects.
        result = self.find_in_cache(plot_type, required_params_labels) if cache else None
        if result is None:
            result = plot_class(self).fit(required_params_values)
            if cache:
                self.put_to_cache(result)
        return result

    def get_params_attributes(self, param_type=None):
        """Returns attributes for all params

        When `param_type` is not None, then function returns list of dicts.
        Each dict represents one of available attributes for specified param
        type. Field `name` is attribute name and field `values` is a list
        with the value of that attribute for each available param.
        When `param_type` is None, then function returns dict with keys for
        each param type and values are lists described above.

        Parameters
        ----------
        param_type : str or None
            One of ['model', 'variable', 'observation', 'dataset'] or None.
            Specifies attributes of which params should be returned.

        Notes
        -----
        Attributes are used for dynamically modifying observations
        https://arena.drwhy.ai/docs/guide/modifying-observations

        Returns
        -------
        dict or list
        """
        if param_type is None:
            obj = {}
            for p in ['model', 'observation', 'variable', 'dataset']:
                obj[p] = self.get_params_attributes(p)
            return obj
        if not self.enable_attributes:
            return []
        attrs = Param.get_param_class(param_type).list_attributes(self)
        array = []
        for attr in attrs:
            array.append({
                'name': attr,
                'values': [param.get_attributes().get(attr) for param in self.get_params(param_type)]
            })
        return array

    def get_param_attributes(self, param_type, param_label):
        """Returns attributes for one param

        Function searches for param with specified type and label and
        returns its attributes.

        Parameters
        ----------
        param_type : str
            One of ['model', 'variable', 'observation', 'dataset'].
        param_label : str
            Label of param

        Notes
        -----
        Attributes are used for dynamically modifying observations
        https://arena.drwhy.ai/docs/guide/modifying-observations

        Returns
        -------
        dict
        """
        if not self.enable_attributes:
            return {}
        param_value = self.find_param_value(param_type=param_type, param_label=param_label)
        if param_value:
            return param_value.get_attributes()
        else:
            return {}

    def save(self, filename="datasource.json"):
        """Generate all plots and saves them to JSON file

        Function generates only not cached plots.

        Parameters
        ----------
        filename : str
            Path or filename to output file

        Notes
        -----
        Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

        Returns
        -------
        None
        """
        with open(filename, 'w') as file:
            file.write(get_json(self))

    def upload(self, token=None, arena_url='https://arena.drwhy.ai/', open_browser=True):
        """Generate all plots and uploads them to GitHub Gist

        Function generates only not cached plots. If token is not provided
        then function uses OAuth to open GitHub authorization page.

        Parameters
        ----------
        token : str or None
            GitHub personal access token. If token is None, then OAuth is used.
        arena_url : str
            Address of Arena dashboard instance
        open_browser : bool
            Whether to open Arena after upload.

        Notes
        -----
        Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

        Returns
        -------
        Link to the Arena
        """
        global_check_import('requests')
        if token is None:
            global_check_import('flask')
            global_check_import('flask_cors')
            token = generate_token()
        data_url = upload_arena(self, token)
        url = arena_url + '?data=' + data_url
        if open_browser:
            webbrowser.open(url)
        return url
Methods
def clear_cache(self, plot_type=None)
-
Clears cache
Parameters
plot_type
:str
or None
- If None all cache is cleared. Otherwise only plots with provided plot_type are removed.
Notes
This function must be called from mutex context
Expand source code
def clear_cache(self, plot_type=None):
    """Remove cached plots and refresh the modification timestamp.

    Parameters
    ----------
    plot_type : str or None
        With None the whole cache is emptied. Otherwise only the plots whose
        `plot_type` equals this value are dropped.

    Notes
    -----
    This function must be called from mutex context
    """
    if plot_type is not None:
        # Keep every plot of a different type.
        self.cache = [cached for cached in self.cache if cached.plot_type != plot_type]
    else:
        self.cache = []
    self.update_timestamp()
def fill_cache(self, fixed_params={})
-
Generates all available plots and caches them
This function tries to generate all plots that are not cached already and put them to cache. Range of generated plots can be narrowed using
fixed_params
Parameters
fixed_params
:dict
- This dict specifies which plots should be generated. Only those with
all keys from
fixed_params
present and having the same value will be calculated.
Expand source code
def fill_cache(self, fixed_params=None):
    """Generates all available plots and caches them

    This function tries to generate all plots that are not cached already
    and put them to cache. Range of generated plots can be narrowed using
    `fixed_params`

    Parameters
    ----------
    fixed_params : dict or None
        This dict specifies which plots should be generated. Only those with
        all keys from `fixed_params` present and having the same value will
        be calculated. None (the default) fixes nothing.

    Raises
    ------
    Exception
        If `fixed_params` is neither None nor a dict.
    """
    # None replaces the former shared mutable `{}` default.
    if fixed_params is None:
        fixed_params = {}
    elif not isinstance(fixed_params, dict):
        raise Exception('Params argument must be a dict')
    for plot_class in self.get_supported_plots():
        required_params = plot_class.info.get('requiredParams')
        # Skip plots that do not use every param type fixed by the caller.
        if any(k not in required_params for k in fixed_params):
            continue
        available_params = self.get_available_params()
        # One pool per required param: the fixed value alone, or every available one.
        iteration_pools = [
            available_params.get(p) if fixed_params.get(p) is None else [fixed_params.get(p)]
            for p in required_params
        ]
        # Cartesian product of the pools, in required_params order.
        combinations = [[]]
        for pool in iteration_pools:
            combinations = [x + [y] for x in combinations for y in pool]
        for params_values in combinations:
            params = dict(zip(required_params, params_values))
            self.get_plot(plot_type=plot_class.info.get('plotType'), params_values=params)
def find_in_cache(self, plot_type, params)
-
Function searches for cached plot
Parameters
plot_type
:str
- Value of plot_type field, that requested plot must have
params
:dict
- Keys of this dict are params types (model, observation, variable, dataset) and values are corresponding params labels. Requested plot must have equal params field.
Returns
PlotContainer
or None
Expand source code
def find_in_cache(self, plot_type, params):
    """Look up an already calculated plot.

    Parameters
    ----------
    plot_type : str
        Value the `plot_type` field of the requested plot must have
    params : dict
        Maps params types (model, observation, variable, dataset) to the
        labels of the corresponding params. The requested plot must have an
        equal `params` field.

    Returns
    -------
    PlotContainer or None
        First matching cached plot, or None when there is no match.
    """
    with self.mutex:
        for cached in self.cache:
            if cached.plot_type == plot_type and params == cached.params:
                return cached
        return None
def find_param_value(self, param_type, param_label)
-
Searches for Param object with specified label
Parameters
param_type
:str
- One of ['model', 'variable', 'observation', 'dataset'].
param_label
:str
- Label of searched param
Notes
Information about params https://arena.drwhy.ai/docs/guide/params
Returns
Param
or None
Expand source code
def find_param_value(self, param_type, param_label):
    """Find the Param object that carries the given label.

    Parameters
    ----------
    param_type : str
        One of ['model', 'variable', 'observation', 'dataset'].
    param_label : str
        Label of searched param

    Notes
    -----
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    -------
    Param or None
        None when `param_label` is not a str or no param has that label.
    """
    # A None label is not a str either, so a single check covers both cases.
    if not isinstance(param_label, str):
        return None
    for candidate in self.get_params(param_type):
        if candidate.get_label() == param_label:
            return candidate
    return None
def get_available_params(self)
-
Returns dict containing available params of all types
This method collect result of
get_params
method for each param type into a dict. Keys are param types and values are lists of Param objects.
Notes
Information about params https://arena.drwhy.ai/docs/guide/params
Returns
dict
Expand source code
def get_available_params(self):
    """Collect the available params of every type into one dict.

    The result of `get_params` is gathered for each param type. Keys are
    param types and values are lists of Param objects.

    Notes
    -----
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    -------
    dict
    """
    param_types = ('model', 'observation', 'variable', 'dataset')
    return {kind: self.get_params(kind) for kind in param_types}
def get_option(self, plot_type, option)
-
Returns value of specified option
Parameters
plot_type
:str
- Type of plot, the option is corresponding to.
option
:str
- Name of the option
Notes
List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level
Returns
None
or value
of option
Expand source code
def get_option(self, plot_type, option):
    """Read the current value of one plot option.

    Parameters
    ----------
    plot_type : str
        Type of plot, the option is corresponding to.
    option : str
        Name of the option

    Notes
    -----
    List of plots with described options for each one
    https://arena.drwhy.ai/docs/guide/observation-level

    Returns
    -------
    None or value of option
        None when the plot type has no option of that name.

    Raises
    ------
    Exception
        If `plot_type` has no options registered.
    """
    plot_options = self.options.get(plot_type)
    if plot_options is None:
        raise Exception('Invalid plot_type')
    if option in plot_options.keys():
        with self.mutex:
            return self.options.get(plot_type).get(option)
    return None
def get_param_attributes(self, param_type, param_label)
-
Returns attributes for one param
Function searches for param with specified type and label and returns its attributes.
Parameters
param_type
:str
- One of ['model', 'variable', 'observation', 'dataset'].
param_label
:str
- Label of param
Notes
Attributes are used for dynamically modifying observations https://arena.drwhy.ai/docs/guide/modifying-observations
Returns
dict
Expand source code
def get_param_attributes(self, param_type, param_label):
    """Return the attributes of a single param.

    The param is looked up by type and label; its attributes are returned
    when it is found.

    Parameters
    ----------
    param_type : str
        One of ['model', 'variable', 'observation', 'dataset'].
    param_label : str
        Label of param

    Notes
    -----
    Attributes are used for dynamically modifying observations
    https://arena.drwhy.ai/docs/guide/modifying-observations

    Returns
    -------
    dict
        Empty when attributes are disabled or the param was not found.
    """
    if not self.enable_attributes:
        return {}
    found = self.find_param_value(param_type=param_type, param_label=param_label)
    return found.get_attributes() if found else {}
def get_params(self, param_type)
-
Returns list of available params
Parameters
param_type
:str
- One of ['model', 'variable', 'observation', 'dataset']. Params of this type will be returned
Notes
Information about params https://arena.drwhy.ai/docs/guide/params
Returns
List
of Param objects
Expand source code
def get_params(self, param_type):
    """Returns list of available params

    Parameters
    ----------
    param_type : str
        One of ['model', 'variable', 'observation', 'dataset']. Params of
        this type will be returned

    Notes
    -----
    Information about params https://arena.drwhy.ai/docs/guide/params

    The returned list is the internal one, not a copy.

    Returns
    -------
    List of Param objects

    Raises
    ------
    Exception
        If `param_type` is not one of the supported types.
    """
    if param_type == 'observation':
        with self.mutex:
            return self.observations
    elif param_type == 'variable':
        with self.mutex:
            if not self.variables_cache:
                # Extract column names from every dataset in self.datasets list and flatten it
                result_datasets = [col for dataset in self.datasets for col in dataset.variables]
                # Extract column names from every model in self.models list and flatten it
                result_explainers = [col for model in self.models for col in model.variables]
                result_str = np.unique(result_datasets + result_explainers).tolist()
                self.variables_cache = [VariableParam(x) for x in result_str]
                if self.enable_attributes:
                    for var in self.variables_cache:
                        try:
                            for dataset in self.datasets:
                                if var.variable in dataset.variables:
                                    var.update_attributes(dataset.dataset[var.variable])
                            for model in self.models:
                                if var.variable in model.variables:
                                    var.update_attributes(model.explainer.data[var.variable])
                        except Exception:
                            # Best effort: a variable whose attributes cannot be
                            # computed is kept without attributes. Narrowed from a
                            # bare `except:` so KeyboardInterrupt/SystemExit propagate.
                            var.clear_attributes()
            return self.variables_cache
    elif param_type == 'model':
        with self.mutex:
            return self.models
    elif param_type == 'dataset':
        with self.mutex:
            return self.datasets
    else:
        raise Exception('Invalid param type')
def get_params_attributes(self, param_type=None)
-
Returns attributes for all params
When
param_type
is not None, then function returns list of dicts. Each dict represents one of the available attributes for the specified param type. Field name
is the attribute name and field values
is a list with the value of that attribute for each available param. When param_type
is None, then function returns dict with keys for each param type and values are lists described above.
Parameters
param_type
:str
or None
- One of ['model', 'variable', 'observation', 'dataset'] or None. Specifies attributes of which params should be returned.
Notes
Attributes are used for dynamically modifying observations https://arena.drwhy.ai/docs/guide/modifying-observations
Returns
dict
or list
Expand source code
def get_params_attributes(self, param_type=None):
    """Returns attributes for all params

    When `param_type` is not None, the function returns a list of dicts.
    Each dict describes one attribute available for the given param type:
    field `name` is the attribute name and field `values` is a list with the
    value of that attribute for each available param. When `param_type` is
    None, the function returns a dict keyed by param type whose values are
    the lists described above.

    Parameters
    -----------
    param_type : str or None
        One of ['model', 'variable', 'observation', 'dataset'] or None.
        Specifies attributes of which params should be returned.

    Notes
    --------
    Attributes are used for dynamically modifying observations
    https://arena.drwhy.ai/docs/guide/modifying-observations

    Returns
    --------
    dict or list
    """
    if param_type is None:
        # One entry per param type, each computed by the single-type branch.
        return {
            kind: self.get_params_attributes(kind)
            for kind in ('model', 'observation', 'variable', 'dataset')
        }
    if not self.enable_attributes:
        return []
    attribute_names = Param.get_param_class(param_type).list_attributes(self)
    return [
        {
            'name': attribute,
            'values': [
                param.get_attributes().get(attribute)
                for param in self.get_params(param_type)
            ]
        }
        for attribute in attribute_names
    ]
def get_plot(self, plot_type, params_values, cache=True)
-
Returns plot for specified type and params
The function searches for the plot in the cache; when it is not present, it creates the requested plot and puts it into the cache.
Parameters
plot_type
:str
- Type of plot to be generated
params_values
:dict
- Dict for param types as keys and Param objects as values
cache
:bool
- Whether to search for the plot in the cache and put the calculated plot into the cache.
Returns
PlotContainer
Expand source code
def get_plot(self, plot_type, params_values, cache=True):
    """Returns plot for specified type and params

    The function searches for the plot in the cache; when it is not there,
    the requested plot is created and (optionally) stored in the cache.

    Parameters
    -----------
    plot_type : str
        Type of plot to be generated
    params_values : dict
        Dict with param types as keys and Param objects as values
    cache : bool
        Whether to look the plot up in the cache and to put a freshly
        calculated plot into the cache.

    Returns
    --------
    PlotContainer

    Raises
    --------
    Exception
        When no enabled plot has the requested type or when a param required
        by the plot is missing in `params_values`.
    """
    plot_class = None
    for candidate in self.plots:
        if candidate.info.get('plotType') == plot_type:
            plot_class = candidate
            break
    if plot_class is None:
        raise Exception('Not supported plot type')
    plot_type = plot_class.info.get('plotType')

    # Keep only the params this plot declares as required.
    values = {}
    labels = {}
    for name in plot_class.info.get('requiredParams'):
        if params_values.get(name) is None:
            raise Exception('Required param is missing')
        values[name] = params_values.get(name)
        labels[name] = params_values.get(name).get_label()

    plot = None
    if cache:
        plot = self.find_in_cache(plot_type, labels)
    if plot is None:
        plot = plot_class(self).fit(values)
        if cache:
            self.put_to_cache(plot)
    return plot
def get_supported_plots(self)
-
Returns plots classes that can produce at least one valid chart for this Arena.
Returns
List
ofclasses extending PlotContainer
Expand source code
def get_supported_plots(self):
    """Returns plot classes that can produce at least one valid chart for this Arena.

    Each enabled plot class is asked through its `test_arena` method whether
    it supports this Arena; only the classes answering truthy are kept, in
    their original order.

    Returns
    -----------
    List of classes extending PlotContainer
    """
    supported = []
    for plot_class in self.plots:
        if plot_class.test_arena(self):
            supported.append(plot_class)
    return supported
def list_available_params(self)
-
Returns dict containing labels of available params of all types
This method collects the result of `list_params` for each param type into a dict. Keys are param types and values are lists of labels.
Notes
Information about params https://arena.drwhy.ai/docs/guide/params
Returns
dict
Expand source code
def list_available_params(self):
    """Returns dict containing labels of available params of all types

    This method collects the result of `list_params` for each param type
    into a dict. Keys are param types and values are lists of labels.

    Notes
    --------
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    --------
    dict
    """
    return {
        kind: self.list_params(kind)
        for kind in ('model', 'observation', 'variable', 'dataset')
    }
def list_params(self, param_type)
-
Returns list of labels of available params
Parameters
param_type
:str
- One of ['model', 'variable', 'observation', 'dataset']. Labels of params of this type will be returned
Notes
Information about params https://arena.drwhy.ai/docs/guide/params
Returns
List
ofstr
Expand source code
def list_params(self, param_type):
    """Returns list of labels of available params

    Parameters
    -----------
    param_type : str
        One of ['model', 'variable', 'observation', 'dataset']. Labels of
        params of this type will be returned

    Notes
    --------
    Information about params https://arena.drwhy.ai/docs/guide/params

    Returns
    --------
    List of str
    """
    labels = []
    for param in self.get_params(param_type):
        labels.append(param.get_label())
    return labels
def print_options(self, plot_type=None)
-
Prints available options for plots
Parameters
plot_type
:str
orNone
- When not None, then only options for plots with this plot_type will be printed.
Notes
List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level
Expand source code
def print_options(self, plot_type=None):
    """Prints available options for plots

    Parameters
    -----------
    plot_type : str or None
        When not None, then only options for plots with this plot_type will
        be printed. When None, or when no enabled plot has this type, options
        of every enabled plot are printed.

    Notes
    --------
    List of plots with described options for each one
    https://arena.drwhy.ai/docs/guide/observation-level
    """
    matched = None
    for candidate in self.plots:
        if candidate.info.get('plotType') == plot_type:
            matched = candidate
            break

    if plot_type is None or matched is None:
        # Nothing specific to show: print every enabled plot in turn.
        for plot in self.plots:
            self.print_options(plot.info.get('plotType'))
        return

    # Header: plot type in bold (ANSI escape codes) followed by a separator.
    print('\n\033[1m' + matched.info.get('plotType') + '\033[0m')
    print('---------------------------------')
    for name, spec in matched.options.items():
        current = self.options.get(plot_type).get(name)
        print(name + ': ' + str(current) + ' #' + spec.get('desc'))
def push_dataset(self, dataset, target, label, precalculate=None)
-
Adds dataset to Arena
The pushed dataset will be visualised using exploratory data analysis plots. The function creates a DatasetParam object with the specified label and target name. When precalculation is enabled, it triggers filling the cache.
Parameters
dataset
:pandas.DataFrame
- Data frame to be visualised using EDA plots. This dataset should contain target variable.
target
:str
- Name of target column
label
:str
- Label for this dataset
precalculate
:bool
orNone
- Overrides constructor
precalculate
parameter when it is not None. If true, then only plots using this dataset will be precalculated.
Expand source code
def push_dataset(self, dataset, target, label, precalculate=None):
    """Adds dataset to Arena

    The pushed dataset will be visualised using exploratory data analysis
    plots. The function creates a DatasetParam object with the specified
    label and target name. When precalculation is enabled, it triggers
    filling the cache.

    Parameters
    -----------
    dataset : pandas.DataFrame
        Data frame to be visualised using EDA plots. This dataset should
        contain the target variable.
    target : str
        Name of target column
    label : str
        Label for this dataset
    precalculate : bool or None
        Overrides constructor `precalculate` parameter when it is not None.
        If true, then only plots using this dataset will be precalculated.

    Raises
    --------
    Exception
        When `dataset` is not a single-level-column DataFrame, `target` is
        not one of its columns, or `label` is empty, not a string or already
        used by another dataset.
    """
    if not isinstance(dataset, DataFrame):
        raise Exception('Dataset argument is not a pandas DataFrame')
    if len(dataset.columns.names) != 1:
        raise Exception('Dataset argument need to have only one level column names')

    if precalculate is None:
        precalculate = self.precalculate
    else:
        precalculate = bool(precalculate)

    target = str(target)
    if target not in dataset.columns:
        raise Exception('Target is not a column from dataset')

    label_is_valid = isinstance(label, str) and len(label) != 0
    if not label_is_valid:
        raise Exception('Label need to be at least one letter')
    if label in self.list_params('dataset'):
        raise Exception('Labels need to be unique')

    new_param = DatasetParam(dataset=dataset, label=label, target=target)
    with self.mutex:
        self.update_timestamp()
        self.datasets.append(new_param)
        # Variables depend on pushed datasets, so the cached list is stale.
        self.variables_cache = []
    # Filling the cache happens after the mutex has been released.
    if precalculate:
        self.fill_cache({'dataset': new_param})
def push_model(self, explainer, precalculate=None)
-
Adds model to Arena
This method encapsulates the explainer in a ModelParam object and appends it to the models field. When precalculation is enabled, it triggers filling the cache.
Parameters
explainer
:Explainer
- Explainer created using dalex package
precalculate
:bool
orNone
- Overrides constructor
precalculate
parameter when it is not None. If true, then only plots using this model will be precalculated.
Expand source code
def push_model(self, explainer, precalculate=None):
    """Adds model to Arena

    This method encapsulates the explainer in a ModelParam object and appends
    it to the `models` field. When precalculation is enabled, it triggers
    filling the cache.

    Parameters
    -----------
    explainer : dalex.Explainer
        Explainer created using dalex package
    precalculate : bool or None
        Overrides constructor `precalculate` parameter when it is not None.
        If true, then only plots using this model will be precalculated.

    Raises
    --------
    Exception
        When `explainer` is not an Explainer or its label is already used by
        a pushed model.
    """
    if not isinstance(explainer, Explainer):
        raise Exception('Invalid Explainer argument')
    if explainer.label in self.list_params('model'):
        raise Exception('Explainer with the same label was already added')

    if precalculate is None:
        precalculate = self.precalculate
    else:
        precalculate = bool(precalculate)

    new_param = ModelParam(explainer)
    with self.mutex:
        self.update_timestamp()
        self.models.append(new_param)
        # Variables depend on pushed models, so the cached list is stale.
        self.variables_cache = []
    # Filling the cache happens after the mutex has been released.
    if precalculate:
        self.fill_cache({'model': new_param})
def push_observations(self, observations, precalculate=None)
-
Adds observations to Arena
Pushed observations will be used for local explanations. The function creates an ObservationParam object for each row of the pushed dataset. The label for each observation is taken from the row name. When precalculation is enabled, it triggers filling the cache.
Parameters
observations
:pandas.DataFrame
- Data frame of observations to be explained using instance level plots. Label for each observation is taken from row name.
precalculate
:bool
orNone
- Overrides constructor
precalculate
parameter when it is not None. If true, then only plots using these observations will be precalculated.
Expand source code
def push_observations(self, observations, precalculate=None):
    """Adds observations to Arena

    Pushed observations will be used for local explanations. The function
    creates an ObservationParam object for each row of the pushed data frame.
    The label for each observation is taken from the row name (converted to
    str). When precalculation is enabled, it triggers filling the cache.

    Parameters
    -----------
    observations : pandas.DataFrame
        Data frame of observations to be explained using instance level
        plots. Label for each observation is taken from row name.
    precalculate : bool or None
        Overrides constructor `precalculate` parameter when it is not None.
        If true, then only plots using these observations will be
        precalculated.

    Raises
    --------
    Exception
        When `observations` is not a DataFrame with a single-level unique
        index, or when a row label is already used by a pushed observation.
    """
    if not isinstance(observations, DataFrame):
        raise Exception('Observations argument is not a pandas DataFrame')
    if len(observations.index.names) != 1:
        raise Exception('Observations argument need to have only one index')
    if not observations.index.is_unique:
        raise Exception('Observations argument need to have unique indexes')

    if precalculate is None:
        precalculate = self.precalculate
    else:
        precalculate = bool(precalculate)

    known_labels = self.list_params('observation')
    # Labels are strings, so the index is converted before building params.
    observations = observations.set_index(observations.index.astype(str))

    new_params = []
    for row_label in observations.index:
        if row_label in known_labels:
            raise Exception('Indexes of observations need to be unique across all observations')
        new_params.append(ObservationParam(dataset=observations, index=row_label))

    with self.mutex:
        self.update_timestamp()
        self.observations.extend(new_params)
    # Filling the cache happens after the mutex has been released.
    if precalculate:
        for param in new_params:
            self.fill_cache({'observation': param})
def put_to_cache(self, plot_container)
-
Puts new plot to cache
Parameters
plot_container
:PlotContainer
Expand source code
def put_to_cache(self, plot_container):
    """Puts new plot to cache

    The plot is appended to the `cache` list while holding the mutex.

    Parameters
    -----------
    plot_container : PlotContainer

    Raises
    --------
    Exception
        When `plot_container` is not a PlotContainer instance.
    """
    is_container = isinstance(plot_container, PlotContainer)
    if not is_container:
        raise Exception('Invalid plot container')
    with self.mutex:
        self.cache.append(plot_container)
def run_server(self, host='127.0.0.1', port=8181, append_data=False, arena_url='https://arena.drwhy.ai/', disable_logs=True)
-
Starts server for live mode of Arena
Parameters
host
:str
- ip or hostname for the server
port
:int
- port number for the server
append_data
:bool
- if generated link should append data to already existing Arena window.
arena_url
:str
- URL of Arena dashboard
disable_logs
:bool
- if logs should be muted
Notes
Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts
Returns
Link to Arena
Expand source code
def run_server(self, host='127.0.0.1', port=8181, append_data=False, arena_url='https://arena.drwhy.ai/', disable_logs=True):
    """Starts server for live mode of Arena

    The server is started in a separate thread stored in `server_thread`.
    The link to the dashboard is printed and also returned.

    Parameters
    -----------
    host : str
        ip or hostname for the server
    port : int
        port number for the server
    append_data : bool
        if generated link should append data to already existing Arena window.
    arena_url : str
        URL of Arena dashboard
    disable_logs : bool
        if logs should be muted

    Notes
    --------
    Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

    Returns
    -----------
    str
        Link to Arena

    Raises
    --------
    Exception
        When the server is already running.
    """
    if self.server_thread:
        raise Exception('Server is already running. To stop ip use arena.stop_server().')
    global_check_import('flask')
    global_check_import('flask_cors')
    global_check_import('requests')
    self.server_thread = threading.Thread(target=start_server, args=(self, host, port, disable_logs))
    self.server_thread.start()
    # The dashboard either appends this data source to an open window
    # ('append') or loads it as the only one ('data').
    query_key = 'append' if append_data else 'data'
    url = arena_url + '?' + query_key + '=http://' + host + ':' + str(port) + '/'
    print(url)
    # Returning the link matches the documented contract; it used to be
    # printed only.
    return url
def save(self, filename='datasource.json')
-
Generate all plots and saves them to JSON file
Function generates only not cached plots.
Parameters
filename
:str
- Path or filename to output file
Notes
Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts
Returns
None
Expand source code
def save(self, filename="datasource.json"):
    """Generate all plots and saves them to JSON file

    Function generates only not cached plots.

    Parameters
    -----------
    filename : str
        Path or filename to output file

    Notes
    --------
    Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

    Returns
    --------
    None
    """
    # Explicit UTF-8 keeps the JSON file independent of the platform's
    # default text encoding.
    with open(filename, 'w', encoding='utf-8') as file:
        file.write(get_json(self))
def set_option(self, plot_type, option, value)
-
Sets value for the plot option
Parameters
plot_type
:str
- When None, the value will be set for each plot that has an option with the name given in the `option` argument. Otherwise only for plots with the specified type.
option
:str
- Name of the option
value
:*
- Value to be set
Notes
List of plots with described options for each one https://arena.drwhy.ai/docs/guide/observation-level
Expand source code
def set_option(self, plot_type, option, value):
    """Sets value for the plot option

    Parameters
    -----------
    plot_type : str
        When None, the value will be set for each plot that has an option
        with the name given in the `option` argument. Otherwise only for
        plots with the specified type.
    option : str
        Name of the option
    value : *
        Value to be set

    Notes
    --------
    List of plots with described options for each one
    https://arena.drwhy.ai/docs/guide/observation-level

    A plot type that does not have the given option is silently skipped.

    Raises
    --------
    Exception
        When `plot_type` is not None and no options exist for it.
    """
    if plot_type is None:
        # Apply the option to every enabled plot, one plot type at a time.
        for plot in self.plots:
            self.set_option(plot.info.get('plotType'), option, value)
        return

    plot_options = self.options.get(plot_type)
    if plot_options is None:
        raise Exception('Invalid plot_type')

    if option in plot_options.keys():
        with self.mutex:
            self.options.get(plot_type)[option] = value
        # Plots calculated with the old value are dropped; both calls run
        # after the mutex has been released.
        self.clear_cache(plot_type)
        if self.precalculate:
            self.fill_cache()
def stop_server(self)
-
Stops running server
Expand source code
def stop_server(self):
    """Stops running server

    Calls `_stop_server`, waits for the server thread to finish and resets
    `server_thread` to None.

    Raises
    --------
    Exception
        When the server is not running.
    """
    server_is_running = bool(self.server_thread)
    if not server_is_running:
        raise Exception('Server is not running')
    self._stop_server()
    # Wait until the server thread has actually finished before forgetting it.
    self.server_thread.join()
    self.server_thread = None
def update_timestamp(self)
-
Updates timestamp
Notes
This function must be called from mutex context
Expand source code
def update_timestamp(self):
    """Updates timestamp

    Stores the current time, as a POSIX timestamp, in the `timestamp`
    attribute.

    Notes
    -------
    This function must be called from mutex context
    """
    self.timestamp = datetime.now().timestamp()
def upload(self, token=None, arena_url='https://arena.drwhy.ai/', open_browser=True)
-
Generate all plots and uploads them to GitHub Gist
Function generates only not cached plots. If token is not provided then function uses OAuth to open GitHub authorization page.
Parameters
token
:str
orNone
- GitHub personal access token. If token is None, then OAuth is used.
arena_url
:str
- Address of Arena dashboard instance
open_browser
:bool
- Whether to open Arena after upload.
Notes
Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts
Returns
Link to the Arena
Expand source code
def upload(self, token=None, arena_url='https://arena.drwhy.ai/', open_browser=True):
    """Generate all plots and uploads them to GitHub Gist

    Function generates only not cached plots. If token is not provided then
    function uses OAuth to open GitHub authorization page.

    Parameters
    -----------
    token : str or None
        GitHub personal access token. If token is None, then OAuth is used.
    arena_url : str
        Address of Arena dashboard instance
    open_browser : bool
        Whether to open Arena after upload.

    Notes
    --------
    Read more about data sources https://arena.drwhy.ai/docs/guide/basic-concepts

    Returns
    --------
    Link to the Arena
    """
    global_check_import('requests')
    if token is None:
        # Extra packages are checked only when a token has to be generated.
        for package in ('flask', 'flask_cors'):
            global_check_import(package)
        token = generate_token()

    uploaded_url = upload_arena(self, token)
    arena_link = arena_url + '?data=' + uploaded_url
    if open_browser:
        webbrowser.open(arena_link)
    return arena_link