
TL;DR, you can jump straight into the visuals and application with cheem::run_app(), but we suggest you read the introduction to get situated with the context first.

Introduction

Non-linear models regularly result in more accurate predictions than their linear counterparts. However, the number and complexity of their terms make them more opaque and harder to interpret. Our ability to understand how features (variables or predictors) influence predictions is important to a wide range of audiences. Attempts to bring interpretability to such complex models are an important aspect of eXplainable Artificial Intelligence (XAI).

Local explanations are one such tool used in XAI. They attempt to approximate the feature importance in the vicinity of one instance (observation). That is to say that they give an approximation of linear terms at the position of one in-sample or out-of-sample observation.
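To make the idea of per-observation linear terms concrete, here is a minimal sketch in base R. It assumes a plain linear model and treats the features as independent, in which case the SHAP attribution of a feature reduces to its coefficient times the feature's deviation from its mean. The dataset (mtcars) and the object names are chosen only for illustration and are not part of cheem.

```r
## Fit a linear model on a built-in dataset (mtcars is used only for illustration)
X   <- mtcars[, c("wt", "hp", "disp")]
fit <- lm(mpg ~ wt + hp + disp, data = mtcars)

## Center each feature at its mean, then scale by the fitted coefficient
beta     <- coef(fit)[colnames(X)]
centered <- sweep(as.matrix(X), 2, colMeans(X))
attr_mat <- sweep(centered, 2, beta, `*`)  # [n, p] local attributions

## Attributions of one observation, its local explanation
attr_mat[1, ]

## Sanity check: baseline + row sum of attributions recovers each prediction
baseline <- mean(predict(fit))
all.equal(unname(baseline + rowSums(attr_mat)), unname(predict(fit)))
```

For non-linear models the attributions no longer have this closed form, which is why dedicated methods such as tree SHAP are used instead.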

If the analyst can explore how models lead to bad predictions, it can suggest insight into issues with the data or suggest models that may be more robust to misclassified instances or extreme residuals. An analyst may want to explore the support of the feature contributions to see where the explanations make sense or where they may be completely unreliable. We propose this sort of analysis, conducted with interactive graphics, in the analysis and R package titled cheem.

Preprocessing

This framework is broadly applicable to any model and compatible local explanation. We will illustrate with an xgboost::xgboost() model (xgb) and the tree SHAP local explanation from shapviz::shapviz(). The model attempts to predict housing sale price from 11 predictors for 338 sale events from one neighborhood in the 2018 Ames data.

The first things we need are the predictions and a local explanation (or other embedded space). Here we create an xgb model, create predictions, and find the SHAP values of each observation.

## Download if not installed
if(!require(cheem))    install.packages("cheem", dependencies = TRUE)
if(!require(xgboost))  install.packages("xgboost", dependencies = TRUE)
if(!require(treeshap)) install.packages("treeshap", dependencies = TRUE)
if(!require(shapviz))  install.packages("shapviz", dependencies = TRUE)
## Load onto session
library(cheem)
library(xgboost)
library(shapviz)
library(magrittr) ## Provides the %>% pipe used below

## Setup
X    <- amesHousing2018_NorthAmes[, 1:9]
Y    <- amesHousing2018_NorthAmes$SalePrice
clas <- amesHousing2018_NorthAmes$SubclassMS

## Model and predict
ames_train    <- data.matrix(X) %>% xgb.DMatrix(label = Y)
ames_xgb_fit  <- xgboost(data = ames_train, max.depth = 3, nrounds = 25)
ames_xgb_pred <- predict(ames_xgb_fit, newdata = ames_train)
ames_xgb_pred %>% head()

## SHAP values
shp <- shapviz(ames_xgb_fit, X_pred = ames_train, X = X)
## Keep just the [n, p] local explanations
ames_xgb_shap <- shp$S
ames_xgb_shap %>% head()

Note that the choice of model, prediction, and local explanation (or other embedding) is up to the analyst and is not facilitated by cheem. Now let’s prepare for the visualization of these spaces with a cheem::cheem_ls() call before we start our analysis.

## Preprocessing for cheem analysis
ames_chm <- cheem_ls(X, Y,
                     class      = clas,
                     attr_df    = ames_xgb_shap,
                     pred       = ames_xgb_pred,
                     label      = "Ames, xgb, shap")
names(ames_chm)

Cheem viewer

We have extracted tree SHAP, a feature importance measure in the vicinity of each observation. We need to identify an instance of interest to explore; we do so with the linked brushing available in the global view. Then we will vary the contributions of different features to test the support of an explanation in a radial tour.

Global view

To get a more complete view, let’s look at approximations of the data space, attribution space, and model fits side by side, with linked brushing provided by plotly and crosstalk. We have identified an observation with a large Mahalanobis distance (in data space) and its closest neighbor in attribution space.

prim <- 1
comp <- 17
global_view(ames_chm, primary_obs = prim, comparison_obs = comp,
            height_px = 240, width_px = 720,
            as_ggplot = TRUE, color = "log_maha.data")
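The idea of flagging an observation by its Mahalanobis distance can be sketched in base R. This is an illustration under stated assumptions, not how cheem computes the value behind color = "log_maha.data": the dataset (iris), the log transform of the unsquared distance, and the names prim_candidate and comp_candidate are all hypothetical. The nearest neighbor is found in the same space here for simplicity; in the analysis above it would be searched for in the attribution matrix instead.

```r
## Numeric features of a built-in dataset, used only for illustration
dat <- as.matrix(iris[, 1:4])

## Squared Mahalanobis distance of each row from the column means
d2       <- mahalanobis(dat, center = colMeans(dat), cov = cov(dat))
log_maha <- log(sqrt(d2))

## Row with the largest distance, a candidate primary instance
prim_candidate <- which.max(d2)

## Its nearest neighbor by Euclidean distance, a candidate comparison instance
dists <- sqrt(rowSums(sweep(dat, 2, dat[prim_candidate, ])^2))
dists[prim_candidate] <- Inf  # exclude the instance itself
comp_candidate <- which.min(dists)
```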

From this global view we want to identify a primary instance (PI) and optionally a comparison instance (CI) to explore. Misclassified instances or observations with high residuals are good targets for further exploration. One point sticks out in this case. Instance 243 (shown as *) is a Gentoo (purple) penguin, while the model predicts it to be a Chinstrap penguin. Penguin 169 (shown as x) is reasonably close by and correctly predicted as Gentoo. In practice we used linked brushing and misclassification information to guide our search.

Radial tour

There is a lot to unpack here. The normalized distributions of the feature attributions from all instances are shown as parallel coordinate lines. The PI and CI selected above are shown here as a dashed and a dotted line, respectively. The first thing we notice is that the attribution of the PI is close to its (incorrect) prediction of Chinstrap (orange) in terms of bill length (bl) and flipper length (fl). In terms of bill depth and body mass (bd and bm) it is more like its observed species, Gentoo (purple). We select flipper length as the feature to manipulate.

## Normalized attribution basis of the PI
bas <- sug_basis(ames_xgb_shap, rownum = prim)
## Default feature to manipulate:
#### the feature with largest separation between PI and CI attribution
mv  <- sug_manip_var(
  ames_xgb_shap, primary_obs = prim, comparison_obs = comp)
## Make the radial tour
ggt <- radial_cheem_tour(
  ames_chm, basis = bas, manip_var = mv,
  primary_obs = prim, comparison_obs = comp, angle = .15)

## Animate it
animate_gganimate(ggt, fps = 6)
  #height = 2, width = 4.5, units = "in", res = 150
## Or as a plotly html widget
#animate_plotly(ggt, fps = 6)
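The comments in the chunk above describe the starting basis as the normalized attribution of the PI, and the default manipulation feature as the one with the largest separation between the PI and CI attributions. A minimal base-R sketch of those two ideas follows. It is an assumption about what sug_basis() and sug_manip_var() compute, not their implementation; the toy attribution matrix, the helper unit_row(), and the other names are hypothetical.

```r
## Toy [n, p] attribution matrix; the values are made up for illustration
attr_mat <- matrix(
  c(3,  4,  0,
    1, -2,  2,
    0,  5, 12),
  nrow = 3, byrow = TRUE,
  dimnames = list(NULL, c("f1", "f2", "f3")))
prim <- 1
comp <- 3

## Scale one row to unit length so it can serve as a 1D projection basis
unit_row <- function(mat, rownum) {
  v <- mat[rownum, ]
  matrix(v / sqrt(sum(v^2)), ncol = 1, dimnames = list(names(v), NULL))
}
bas_prim <- unit_row(attr_mat, prim)
bas_comp <- unit_row(attr_mat, comp)

## Feature whose normalized attribution differs most between the two rows
manip_candidate <- which.max(abs(bas_prim - bas_comp))
colnames(attr_mat)[manip_candidate]
```

Normalizing each row first keeps the comparison about the direction of the explanation rather than its overall magnitude.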

Starting from the attribution projection, this instance already looks more like its observed Gentoo than the predicted Chinstrap. However, by frame 8, the basis has a full contribution of flipper length and does look more like the predicted Chinstrap. Looking at the parallel coordinate lines on the basis visual, we can see that flipper length has a large gap between the PI and CI. Let’s check the original variables to digest this.

library(ggplot2)
prim <- 1

ggplot(penguins_na.rm, aes(x = bill_length_mm,
                           y = flipper_length_mm,
                           colour = species,
                           shape = species)) +
  geom_point() +
  ## Highlight PI, *
  geom_point(data = penguins_na.rm[prim, ],
             shape = 8, size = 5, alpha = 0.8) +
  ## Theme, scaling, color, and labels
  theme_bw() +
  theme(aspect.ratio = 1) +
  scale_color_brewer(palette = "Dark2") +
  labs(y = "Flipper length [mm]", x = "Bill length [mm]",
       color = "Observed species", shape = "Observed species")

This profile shows the two features that most distinguish the PI from the CI. The PI is nested in between the Chinstrap penguins. That makes this instance particularly hard for a random forest model to classify, as a decision tree can only partition on one value at a time (horizontal and vertical lines here).

Shiny application

We provide an interactive shiny application. Interactive features are made possible with plotly, crosstalk, and DT. We have preprocessed simulated and modern datasets for you to explore this analysis with. Alternatively, bring your own data by saving the return of cheem_ls() as an rds file. Follow along with the example in ?cheem_ls.
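Saving and restoring an object as an rds file uses base R's saveRDS() and readRDS(). The sketch below uses a small placeholder list in place of a real cheem_ls() return so that it runs on its own; the object contents and the file name are hypothetical.

```r
## Stand-in for the list returned by cheem_ls(); the contents are hypothetical
chm_like <- list(type = "regression", note = "placeholder object")

## Write to an .rds file in a temporary directory and read it back
path <- file.path(tempdir(), "my_cheem_ls.rds")
saveRDS(chm_like, file = path)
restored <- readRDS(path)

## The restored object matches the original
identical(chm_like, restored)
```

With real data, the object passed to saveRDS() would be the return of cheem_ls(), such as ames_chm above.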

Conclusion

Interpretability of black-box models is important to maintain. Local explanations extend this interpretability by approximating the feature importance in the vicinity of one instance. We propose post-hoc analysis of these local explanations. First we explore them in a global, full-instance context. Then we explore the support of the local explanation to see where it seems plausible or unreliable.

Other local explanations (& models)

cheem is agnostic to the choice of model and local explanation, but requires both. Above we illustrated using a random forest to predict penguin species. Below we demonstrate using other attribution spaces from different models.

shapviz (& xgb classification)

shapviz is being actively maintained and is hosted on CRAN. It is compatible with H2O, lgb, and xgb models.

https://github.com/ModelOriented/shapviz

if(!require(shapviz)) install.packages("shapviz")
if(!require(xgboost)) install.packages("xgboost")
library(shapviz)
library(xgboost)
set.seed(3653)

## Setup
X    <- spinifex::penguins_na.rm[, 1:4]
Y    <- spinifex::penguins_na.rm$species
clas <- spinifex::penguins_na.rm$species

## Model and predict
peng_train    <- data.matrix(X) %>%
  xgb.DMatrix(label = Y)
peng_xgb_fit  <- xgboost(data = peng_train, max.depth = 3, nrounds = 25)
peng_xgb_pred <- predict(peng_xgb_fit, newdata = peng_train)

## SHAP
peng_xgb_shap <- shapviz(peng_xgb_fit, X_pred = peng_train, X = X)
## Keep just the [n, p] local explanations
peng_xgb_shap <- peng_xgb_shap$S

treeshap (& randomForest regression)

treeshap is only available on CRAN. It is compatible with many tree-based models including gbm, lgb, rf, ranger, and xgb models.

https://github.com/ModelOriented/treeshap

if(!require(treeshap)) install.packages("treeshap")
if(!require(randomForest)) install.packages("randomForest")
library(treeshap)
library(randomForest)

## Setup
X    <- spinifex::wine[, -(1:2)] ## Drop the first two columns; -1:2 would mix negative and positive indices
Y    <- spinifex::wine$Alcohol
clas <- spinifex::wine$Type
  
## Fit randomForest::randomForest
wine_rf_fit  <- randomForest::randomForest(
  X, Y, ntree = 125,
  mtry = ifelse(is_discrete(Y), sqrt(ncol(X)), ncol(X) / 3),
  nodesize = max(ifelse(is_discrete(Y), 1, 5), nrow(X) / 500))
wine_rf_pred <- predict(wine_rf_fit)

## treeshap::treeshap()
wine_rf_tshap <- wine_rf_fit %>%
  treeshap::randomForest.unify(X) %>%
  treeshap::treeshap(X, interactions = FALSE, verbose = FALSE)
## Keep just the [n, p] local explanations
wine_rf_tshap <- wine_rf_tshap$shaps

DALEX (& LM regression)

DALEX is a popular and versatile XAI package available on CRAN. It is compatible with many models, but it uses the original, slower variant of SHAP local explanation. Expect long run times for sizable data or complex models.

https://ema.drwhy.ai/shapley.html#SHAPRcode

if(!require(DALEX)) install.packages("DALEX")
library(DALEX)

## Setup
X    <- dragons[, c(1:4, 6)]
Y    <- dragons$life_length
clas <- dragons$colour

## Model and predict
drag_lm_fit  <- lm(data = data.frame(Y, X), Y ~ .)
drag_lm_pred <- predict(drag_lm_fit)

## SHAP via DALEX, versatile but slow
drag_lm_exp <- explain(drag_lm_fit, data = X, y = Y,
                       label = "Dragons, LM, SHAP")
## DALEX::predict_parts_shap is flexible, but slow and one row at a time
drag_lm_shap <- matrix(NA, nrow(X), ncol(X),
                       dimnames = list(NULL, colnames(X)))
for(i in seq_len(nrow(X))){
  pps <- predict_parts_shap(drag_lm_exp, new_observation = X[i, ])
  ## Keep just the [n, p] local explanations:
  #### mean contribution per feature, ordered to match the columns of X
  drag_lm_shap[i, ] <- tapply(
    pps$contribution, pps$variable_name, mean, na.rm = TRUE)[colnames(X)]
}
drag_lm_shap <- as.data.frame(drag_lm_shap)