In order for lime to support your model of choice, lime needs to be able to get predictions from the model in a standardised way, and it needs to know whether the model is a classification or a regression model. For the former, it calls the predict_model() generic, for which the user is free to supply methods without overriding the standard predict() method. For the latter, the model must respond to the model_type() generic.

predict_model(x, newdata, type, ...)

model_type(x, ...)

Arguments

x

A model object

newdata

The new observations to predict

type

Either 'raw' to indicate predicted values, or 'prob' to indicate class probabilities

...

Passed on to the model's predict() method

Value

A data.frame in the case of predict_model(). If type = 'raw', it will contain one column named 'Response' holding the predicted values. If type = 'prob', it will contain one column per possible class, named after the class, each column holding the probability score for membership of that class. For model_type(), a character string; currently either 'regression' or 'classification' is supported.
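The layout requirements above can be checked mechanically. The sketch below is an assumption, not part of lime: check_prediction_shape() is a hypothetical helper name, and it only tests the column layout described in this section (a single 'Response' column for 'raw'; one numeric column per class, with values between 0 and 1, for 'prob').

```r
# Hypothetical helper (not part of lime): checks that a data.frame returned
# by a predict_model() method has the column layout described above.
check_prediction_shape <- function(pred, type, classes = NULL) {
  if (!is.data.frame(pred)) stop("predict_model() must return a data.frame")
  if (type == "raw") {
    # Raw predictions: exactly one column, named 'Response'
    identical(names(pred), "Response")
  } else if (type == "prob") {
    # Class probabilities: one column per class, named after the class,
    # each holding numeric probability scores between 0 and 1
    setequal(names(pred), classes) &&
      all(vapply(pred, function(col) is.numeric(col) && all(col >= 0 & col <= 1),
                 logical(1)))
  } else {
    stop("type must be either 'raw' or 'prob'")
  }
}

# Usage with hand-made outputs
raw <- data.frame(Response = c("setosa", "virginica"), stringsAsFactors = FALSE)
prob <- data.frame(setosa = c(0.9, 0.2), virginica = c(0.1, 0.8))
check_prediction_shape(raw, "raw")
check_prediction_shape(prob, "prob", classes = c("setosa", "virginica"))
```

Both calls in the usage part return TRUE; renaming a column or passing a probability outside [0, 1] makes the check fail.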

Supported Models

Out of the box, lime supports the following model objects:

  • train from caret

  • WrappedModel from mlr

  • xgb.Booster from xgboost

  • H2OModel from h2o

  • keras.engine.training.Model from keras

  • lda from MASS (used for low-dependency examples)

If your model is not one of the above, you'll need to implement support yourself. If the model has a predict interface mimicking that of predict.train() from caret, it will be enough to wrap your model in as_classifier()/as_regressor() to gain support. Otherwise, you'll need to implement a predict_model() method and potentially a model_type() method (if the latter is omitted, the model must be wrapped in as_classifier()/as_regressor() every time it is passed to lime()).
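As a sketch of the manual route for a regression model, assume a hypothetical S3 class my_regressor that wraps a fitted lm() model and whose interface does not mimic caret. The class name and its constructor are illustrative only; the two methods follow the return values described above. Rejecting every type other than 'raw' is an assumption of this sketch, since a regression model has no class probabilities to report.

```r
# Hypothetical model class (not part of lime): a thin wrapper around lm()
my_regressor <- function(formula, data) {
  structure(list(fit = lm(formula, data = data)), class = "my_regressor")
}

# predict_model() method: returns the predicted values as a one-column
# data.frame named 'Response'. Only type = 'raw' makes sense for a
# regression model, so other types are rejected (an assumption of this sketch).
predict_model.my_regressor <- function(x, newdata, type, ...) {
  if (type != "raw") stop("my_regressor only supports type = 'raw'")
  res <- predict(x$fit, newdata = newdata, ...)
  data.frame(Response = unname(res))
}

# model_type() method: marks this model as a regression model
model_type.my_regressor <- function(x, ...) "regression"

# Usage: the methods are called directly here so the sketch runs without lime
m <- my_regressor(mpg ~ wt + hp, data = mtcars)
predict_model.my_regressor(m, mtcars[1:3, ], type = "raw")
model_type.my_regressor(m)
```

With lime attached, the same objects can be used through the generics, for example predict_model(m, newdata, type = 'raw') and model_type(m), since S3 dispatch picks up the methods by class.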

Examples

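# Example of adding support for lda models (already available in lime)
predict_model.lda <- function(x, newdata, type, ...) {
  # predict() on an lda fit (from MASS) returns a list whose 'class' element
  # holds the predicted classes and whose 'posterior' element holds a matrix
  # of class probabilities with one column per class
  res <- predict(x, newdata = newdata, ...)
  switch(
    type,
    raw = data.frame(Response = res$class, stringsAsFactors = FALSE),
    prob = as.data.frame(res$posterior, check.names = FALSE)
  )
}

# Tell lime that lda models are classifiers
model_type.lda <- function(x, ...) 'classification'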