In order to have lime support your model of choice, lime needs to be able to get predictions from the model in a standardised way, and it needs to know whether it is a classification or a regression model. For the former, it calls the predict_model() generic, which the user is free to supply methods for without overriding the standard predict() method. For the latter, the model must respond to the model_type() generic.
predict_model(x, newdata, type, ...)
model_type(x, ...)
x: A model object
newdata: The new observations to predict
type: Either 'raw' to indicate predicted values, or 'prob' to indicate class probabilities
...: Passed on to the predict method
A data.frame in the case of predict_model(). If type = 'raw', it will contain one column named 'Response' holding the predicted values. If type = 'prob', it will contain a column for each of the possible classes, named after the class, with each column holding the probability score for membership in that class. For model_type(), a character string: either 'regression' or 'classification', the only two values currently supported.
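The output contract above can be checked mechanically. The sketch below uses a hypothetical helper, check_prediction_output(), which is not part of lime; it only tests the properties stated above (a single 'Response' column for 'raw'; numeric columns named after the classes, with values between 0 and 1, for 'prob').

```r
# Hypothetical helper (not part of lime): checks that a data.frame follows
# the predict_model() output contract described above.
check_prediction_output <- function(pred, type, classes = NULL) {
  stopifnot(is.data.frame(pred))
  if (type == "raw") {
    # exactly one column, named 'Response'
    return(identical(names(pred), "Response"))
  }
  if (type == "prob") {
    # one column per class, named after the class (if classes are given)
    named_ok <- is.null(classes) || setequal(names(pred), classes)
    # every column must be numeric before values are range-checked
    numeric_ok <- all(vapply(pred, is.numeric, logical(1)))
    in_range <- numeric_ok && all(as.matrix(pred) >= 0 & as.matrix(pred) <= 1)
    return(named_ok && numeric_ok && in_range)
  }
  stop("type must be 'raw' or 'prob'")
}

# Usage
check_prediction_output(data.frame(Response = c("a", "b")), "raw")
check_prediction_output(data.frame(a = c(0.9, 0.2), b = c(0.1, 0.8)),
                        "prob", classes = c("a", "b"))
```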
Out of the box, lime supports the following model objects:
- train from caret
- WrappedModel from mlr
- xgb.Booster from xgboost
- H2OModel from h2o
- keras.engine.training.Model from keras
- lda from MASS (used for low-dependency examples)
If your model is not one of the above, you'll need to implement support yourself. If the model has a predict interface mimicking that of predict.train() from caret, it will be enough to wrap your model in as_classifier()/as_regressor() to gain support. Otherwise, you'll need to implement a predict_model() method and potentially a model_type() method (if the latter is omitted, the model should be wrapped in as_classifier()/as_regressor() every time it is used in lime()).
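As a regression counterpart to the lda example, here is a minimal sketch of the two methods for stats::lm, which is not in the list of supported models above. It assumes that a regression model is only ever asked for type = 'raw'; the error for 'prob' is an illustrative choice, not documented lime behaviour.

```r
# Sketch only: lm support written by hand, following the contract above.
predict_model.lm <- function(x, newdata, type, ...) {
  if (type != "raw") {
    # assumption: class probabilities make no sense for a regression model
    stop("a regression model can only return raw predictions")
  }
  # predict.lm returns a numeric vector; wrap it in the 'Response' column
  data.frame(Response = unname(predict(x, newdata = newdata, ...)))
}

# model_type() method: lm fits are regressions
model_type.lm <- function(x, ...) 'regression'

# Usage (calling the methods directly, so lime need not be attached)
fit <- lm(mpg ~ wt + hp, data = mtcars)
predict_model.lm(fit, mtcars[1:3, ], type = "raw")
model_type.lm(fit)
```

One caveat with this design: glm objects also inherit from class lm, so these methods would be picked up for glm fits too, including logistic regressions that should really be treated as classifiers. A real implementation would want a separate glm method.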
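# Example of adding support for lda models (already available in lime)
library(MASS)  # provides lda() and its predict() method

# predict_model() method: class labels for 'raw', posteriors for 'prob'
predict_model.lda <- function(x, newdata, type, ...) {
  res <- predict(x, newdata = newdata, ...)
  switch(
    type,
    raw = data.frame(Response = res$class, stringsAsFactors = FALSE),
    prob = as.data.frame(res$posterior, check.names = FALSE)
  )
}

# model_type() method: lda models are always classifiers
model_type.lda <- function(x, ...) 'classification'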
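To see what the lda method returns, the same steps can be run by hand on the built-in iris data. This sketch assumes only MASS, which ships with R, and does not need lime to be attached. It shows that the 'prob' data.frame has one column per class, named after the class.

```r
library(MASS)  # lda() and predict.lda()

# Fit an LDA classifier on the built-in iris data
fit <- lda(Species ~ ., data = iris)
newdata <- iris[1:3, 1:4]  # predictor columns only

# Same steps as predict_model.lda, run by hand
res <- predict(fit, newdata = newdata)
raw <- data.frame(Response = res$class)   # what type = 'raw' gives
prob <- as.data.frame(res$posterior)      # what type = 'prob' gives

names(prob)  # "setosa" "versicolor" "virginica"
```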